Unverified Commit 305d5c16 authored by espylapiza, committed by GitHub

[Example] update ogbn-arxiv and ogbn-proteins results (#3018)



* [Example] GCN on ogbn-arxiv dataset

* Add README.md

* Update GCN implementation on ogbn-arxiv

* Update GCN on ogbn-arxiv

* Fix typo

* Use evaluator to get results

* Fix duplicated

* Fix duplicated

* Update GCN on ogbn-arxiv

* Add GAT for ogbn-arxiv

* Update README.md

* Update GAT on ogbn-arxiv

* Update README.md

* Update README

* Update README.md

* Update README.md

* Add GAT implementation for ogbn-proteins

* Add GAT implementation for ogbn-proteins

* Add GAT implementation for ogbn-proteins

* Add GAT implementation for ogbn-proteins

* Add GAT implementation for ogbn-proteins

* Update examples for ogbn-arxiv

* Update examples for ogbn-arxiv

* Update examples for ogbn-arxiv

* Update examples for ogbn-arxiv

* Update examples for ogbn-arxiv

* Update examples for ogbn-arxiv

* Update examples for ogbn-arxiv

* Update examples for ogbn-arxiv

* Update examples for ogbn-arxiv

* Update examples for ogbn-arxiv

* Update examples for ogbn-arxiv

* Update examples for ogbn-arxiv

* Update examples for ogbn-arxiv

* Update examples for ogbn-arxiv

* Add examples for ogbn-products.

* Update examples for ogbn-proteins.

* Update examples for ogbn-proteins.

* Update examples for ogbn-proteins.

* Update examples for ogbn-proteins.

* Update examples for ogbn-proteins.

* Update examples for ogbn-proteins.

* Update examples for ogbn-proteins.

* Update examples for ogbn-proteins.

* Update examples for ogbn-proteins.

* Update examples for ogbn-products.

* Update examples for ogbn-products.

* Update examples for ogbn-arxiv.

* Update examples for ogbn-arxiv.

* Update examples for ogbn-arxiv.

* Update examples for ogbn-proteins.

* Update examples for ogbn-products.

* Update examples for ogbn-proteins.

* Update examples for ogbn-arxiv.

* Update examples for ogbn-proteins.

* Update README.md

* [Example] ogbn-arxiv & ogbn-proteins

* [Example] ogbn-arxiv & ogbn-proteins

* [Example] ogbn-arxiv & ogbn-proteins

* [Example] ogbn-arxiv & ogbn-proteins
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
Co-authored-by: Zihao Ye <expye@outlook.com>
parent 89322856
# DGL examples for ogbn-arxiv
DGL implementation of GCN and GAT for [ogbn-arxiv](https://ogb.stanford.edu/docs/nodeprop/). Using some of the techniques from *Bag of Tricks for Node Classification with Graph Neural Networks* ([https://arxiv.org/abs/2103.13355](https://arxiv.org/abs/2103.13355)).

Requires DGL 0.5 or later versions.

### GCN

For the best score, run `gcn.py` with `--use-linear` and `--use-labels` enabled and you should directly see the result.

```bash
python3 gcn.py --use-linear --use-labels
```
...@@ -12,10 +14,23 @@ python3 gcn.py --use-linear --use-labels
### GAT

For the score of `GAT(norm. adj.)+labels`, run the following command and you should directly see the result.

```bash
python3 gat.py --use-norm --use-labels --no-attn-dst --edge-drop=0.1 --input-drop=0.1
```

For the score of `GAT(norm. adj.)+label reuse`, run the following command and you should directly see the result.

```bash
python3 gat.py --use-norm --use-labels --n-label-iters=1 --no-attn-dst --edge-drop=0.3 --input-drop=0.25
```

For the score of `GAT(norm. adj.)+label reuse+C&S`, run the following commands and you should directly see the result.

```bash
python3 gat.py --use-norm --use-labels --n-label-iters=1 --no-attn-dst --edge-drop=0.3 --input-drop=0.25 --save-pred
python3 correct_and_smooth.py --use-norm
```
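For intuition, the "smooth" step that `correct_and_smooth.py` applies is iterative propagation of the predictions over the graph, mixing each node's scores with its neighbors' average: y ← α·Â·y + (1 − α)·y₀. A minimal NumPy sketch with a toy 3-node graph (the adjacency matrix and values are illustrative, not from the dataset):

```python
import numpy as np

# Toy symmetric adjacency with self-loops (illustrative only).
A = np.array([[1, 1, 0],
              [1, 1, 1],
              [0, 1, 1]], dtype=float)
A_hat = A / A.sum(axis=1, keepdims=True)  # row-normalized, like fn.mean aggregation


def smooth(y0, n_prop=50, alpha=0.8):
    """Repeatedly mix each node's prediction with its neighbors' average."""
    y = y0
    for _ in range(n_prop):
        y = alpha * (A_hat @ y) + (1 - alpha) * y0
        y = np.clip(y, 0, 1)  # same post-step as the script's clamp(0, 1)
    return y


y0 = np.array([[0.9, 0.1],   # confident node
               [0.5, 0.5],   # uncertain node, pulled toward its neighbors
               [0.2, 0.8]])
print(smooth(y0))
```

The actual script does the same propagation with DGL message passing (and an optional symmetric normalization), starting from saved GAT predictions with training-node rows replaced by their one-hot labels.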
## Usage & Options
...@@ -23,59 +38,72 @@ python3 gat.py --use-norm --use-labels
### GCN

```
usage: GCN on OGBN-Arxiv [-h] [--cpu] [--gpu GPU] [--n-runs N_RUNS] [--n-epochs N_EPOCHS] [--use-labels] [--use-linear] [--lr LR] [--n-layers N_LAYERS] [--n-hidden N_HIDDEN]
                         [--dropout DROPOUT] [--wd WD] [--log-every LOG_EVERY] [--plot-curves]

optional arguments:
  -h, --help           show this help message and exit
  --cpu                CPU mode. This option overrides --gpu. (default: False)
  --gpu GPU            GPU device ID. (default: 0)
  --n-runs N_RUNS      running times (default: 10)
  --n-epochs N_EPOCHS  number of epochs (default: 1000)
  --use-labels         Use labels in the training set as input features. (default: False)
  --use-linear         Use linear layer. (default: False)
  --lr LR              learning rate (default: 0.005)
  --n-layers N_LAYERS  number of layers (default: 3)
  --n-hidden N_HIDDEN  number of hidden units (default: 256)
  --dropout DROPOUT    dropout rate (default: 0.75)
  --wd WD              weight decay (default: 0)
  --log-every LOG_EVERY
                       log every LOG_EVERY epochs (default: 20)
  --plot-curves        plot learning curves (default: False)
```
### GAT

```
usage: GAT on OGBN-Arxiv [-h] [--cpu] [--gpu GPU] [--n-runs N_RUNS] [--n-epochs N_EPOCHS] [--use-labels] [--n-label-iters N_LABEL_ITERS] [--no-attn-dst]
                         [--use-norm] [--lr LR] [--n-layers N_LAYERS] [--n-heads N_HEADS] [--n-hidden N_HIDDEN] [--dropout DROPOUT] [--input-drop INPUT_DROP]
                         [--attn-drop ATTN_DROP] [--edge-drop EDGE_DROP] [--wd WD] [--log-every LOG_EVERY] [--plot-curves]

optional arguments:
  -h, --help           show this help message and exit
  --cpu                CPU mode. This option overrides --gpu. (default: False)
  --gpu GPU            GPU device ID. (default: 0)
  --n-runs N_RUNS      running times (default: 10)
  --n-epochs N_EPOCHS  number of epochs (default: 2000)
  --use-labels         Use labels in the training set as input features. (default: False)
  --n-label-iters N_LABEL_ITERS
                       number of label iterations (default: 0)
  --no-attn-dst        Don't use attn_dst. (default: False)
  --use-norm           Use symmetrically normalized adjacency matrix. (default: False)
  --lr LR              learning rate (default: 0.002)
  --n-layers N_LAYERS  number of layers (default: 3)
  --n-heads N_HEADS    number of heads (default: 3)
  --n-hidden N_HIDDEN  number of hidden units (default: 250)
  --dropout DROPOUT    dropout rate (default: 0.75)
  --input-drop INPUT_DROP
                       input drop rate (default: 0.1)
  --attn-drop ATTN_DROP
                       attention dropout rate (default: 0.0)
  --edge-drop EDGE_DROP
                       edge drop rate (default: 0.0)
  --wd WD              weight decay (default: 0)
  --log-every LOG_EVERY
                       log every LOG_EVERY epochs (default: 20)
  --plot-curves        plot learning curves (default: False)
```
## Results

Here are the results over at least 10 runs.

| Method                          | Validation Accuracy | Test Accuracy   | #Parameters |
|:-------------------------------:|:-------------------:|:---------------:|:-----------:|
| GCN                             | 0.7361 ± 0.0009     | 0.7246 ± 0.0021 | 109,608     |
| GCN+linear                      | 0.7397 ± 0.0010     | 0.7270 ± 0.0016 | 218,152     |
| GCN+labels                      | 0.7399 ± 0.0008     | 0.7259 ± 0.0006 | 119,848     |
| GCN+linear+labels               | 0.7442 ± 0.0012     | 0.7306 ± 0.0024 | 238,632     |
| GAT(norm. adj.)+labels          | 0.7508 ± 0.0009     | 0.7366 ± 0.0011 | 1,441,580   |
| GAT(norm. adj.)+label reuse     | 0.7516 ± 0.0008     | 0.7391 ± 0.0012 | 1,441,580   |
| GAT(norm. adj.)+label reuse+C&S | 0.7519 ± 0.0008     | 0.7395 ± 0.0012 | 1,441,580   |
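The `+labels` rows above use the label-as-input trick from the Bag of Tricks paper: one-hot labels of training nodes are appended to the node features, with all-zero label columns for the other nodes. A minimal NumPy sketch with made-up sizes (all names and values here are illustrative, not from the actual scripts):

```python
import numpy as np

n_nodes, n_feats, n_classes = 5, 4, 3
rng = np.random.default_rng(0)
feat = rng.standard_normal((n_nodes, n_feats))
labels = np.array([0, 2, 1, 2, 0])
train_idx = np.array([0, 1, 2])  # labels of these nodes may be fed back as input


def add_labels(feat, labels, idx):
    # Append a one-hot label channel; nodes outside idx get all-zero columns.
    onehot = np.zeros((feat.shape[0], n_classes))
    onehot[idx, labels[idx]] = 1.0
    return np.concatenate([feat, onehot], axis=-1)


x = add_labels(feat, labels, train_idx)
print(x.shape)  # (5, 7): original features plus one column per class
```

Label reuse (`--n-label-iters`) extends this by running the model, writing its predicted class distributions into the label channel of unlabeled nodes, and repeating the forward pass.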
import argparse
import glob

import numpy as np
import torch
import torch.nn.functional as F
from dgl import function as fn
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator

device = None

dataset = "ogbn-arxiv"
n_node_feats, n_classes = 0, 0


def load_data(dataset):
    global n_node_feats, n_classes

    data = DglNodePropPredDataset(name=dataset)
    evaluator = Evaluator(name=dataset)

    splitted_idx = data.get_idx_split()
    train_idx, val_idx, test_idx = splitted_idx["train"], splitted_idx["valid"], splitted_idx["test"]
    graph, labels = data[0]

    n_node_feats = graph.ndata["feat"].shape[1]
    n_classes = (labels.max() + 1).item()

    return graph, labels, train_idx, val_idx, test_idx, evaluator


def preprocess(graph):
    global n_node_feats

    # add reverse edges
    srcs, dsts = graph.all_edges()
    graph.add_edges(dsts, srcs)

    # add self-loop
    print(f"Total edges before adding self-loop {graph.number_of_edges()}")
    graph = graph.remove_self_loop().add_self_loop()
    print(f"Total edges after adding self-loop {graph.number_of_edges()}")

    graph.create_formats_()

    return graph


def general_outcome_correlation(graph, y0, n_prop=50, alpha=0.8, use_norm=False, post_step=None):
    with graph.local_scope():
        y = y0
        for _ in range(n_prop):
            if use_norm:
                degs = graph.in_degrees().float().clamp(min=1)
                norm = torch.pow(degs, -0.5)
                shp = norm.shape + (1,) * (y.dim() - 1)
                norm = torch.reshape(norm, shp)
                y = y * norm

            graph.srcdata.update({"y": y})
            graph.update_all(fn.copy_u("y", "m"), fn.mean("m", "y"))
            y = graph.dstdata["y"]

            if use_norm:
                degs = graph.in_degrees().float().clamp(min=1)
                norm = torch.pow(degs, 0.5)
                shp = norm.shape + (1,) * (y.dim() - 1)
                norm = torch.reshape(norm, shp)
                y = y * norm

            y = alpha * y + (1 - alpha) * y0

            if post_step is not None:
                y = post_step(y)

        return y


def evaluate(labels, pred, train_idx, val_idx, test_idx, evaluator):
    return (
        evaluator(pred[train_idx], labels[train_idx]),
        evaluator(pred[val_idx], labels[val_idx]),
        evaluator(pred[test_idx], labels[test_idx]),
    )


def run(args, graph, labels, pred, train_idx, val_idx, test_idx, evaluator):
    evaluator_wrapper = lambda pred, labels: evaluator.eval(
        {"y_pred": pred.argmax(dim=-1, keepdim=True), "y_true": labels}
    )["acc"]

    y = pred.clone()
    y[train_idx] = F.one_hot(labels[train_idx], n_classes).float().squeeze(1)

    _train_acc, val_acc, test_acc = evaluate(labels, y, train_idx, val_idx, test_idx, evaluator_wrapper)
    # print("train acc:", _train_acc)
    print("original val acc:", val_acc)
    print("original test acc:", test_acc)

    # NOTE: Only "smooth" is performed here; the "correct" step is commented out.
    # dy = torch.zeros(graph.number_of_nodes(), n_classes, device=device)
    # dy[train_idx] = F.one_hot(labels[train_idx], n_classes).float().squeeze(1) - pred[train_idx]
    # smoothed_dy = general_outcome_correlation(
    #     graph, dy, alpha=args.alpha1, use_norm=args.use_norm, post_step=lambda x: x.clamp(-1, 1)
    # )
    # y[train_idx] = F.one_hot(labels[train_idx], n_classes).float().squeeze(1)
    # y = y + args.alpha2 * smoothed_dy  # .clamp(0, 1)

    smoothed_y = general_outcome_correlation(
        graph, y, alpha=args.alpha, use_norm=args.use_norm, post_step=lambda x: x.clamp(0, 1)
    )

    _train_acc, val_acc, test_acc = evaluate(labels, smoothed_y, train_idx, val_idx, test_idx, evaluator_wrapper)
    # print("train acc:", _train_acc)
    print("val acc:", val_acc)
    print("test acc:", test_acc)

    return val_acc, test_acc


def main():
    global device

    argparser = argparse.ArgumentParser(description="implementation of C&S")
    argparser.add_argument("--cpu", action="store_true", help="CPU mode. This option overrides --gpu.")
    argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID.")
    argparser.add_argument("--use-norm", action="store_true", help="Use symmetrically normalized adjacency matrix.")
    argparser.add_argument("--alpha", type=float, default=0.6, help="alpha")
    argparser.add_argument("--pred-files", type=str, default="./output/*.pt", help="address of prediction files")
    args = argparser.parse_args()

    if args.cpu:
        device = torch.device("cpu")
    else:
        device = torch.device(f"cuda:{args.gpu}")

    # load data & preprocess
    graph, labels, train_idx, val_idx, test_idx, evaluator = load_data(dataset)
    graph = preprocess(graph)

    graph, labels, train_idx, val_idx, test_idx = map(
        lambda x: x.to(device), (graph, labels, train_idx, val_idx, test_idx)
    )

    # run
    val_accs, test_accs = [], []

    for pred_file in glob.iglob(args.pred_files):
        print("load:", pred_file)
        pred = torch.load(pred_file)
        val_acc, test_acc = run(args, graph, labels, pred, train_idx, val_idx, test_idx, evaluator)
        val_accs.append(val_acc)
        test_accs.append(test_acc)

    print(args)
    print(f"Ran {len(val_accs)} times")
    print("Val Accs:", val_accs)
    print("Test Accs:", test_accs)
    print(f"Average val accuracy: {np.mean(val_accs)} ± {np.std(val_accs)}")
    print(f"Average test accuracy: {np.mean(test_accs)} ± {np.std(test_accs)}")


if __name__ == "__main__":
    main()

# Namespace(alpha=0.6, cpu=False, gpu=0, pred_files='./output/*.pt', use_norm=True)
# Ran 20 times
# Val Accs: [0.7523742407463337, 0.750729890264774, 0.7524077989194268, 0.7527098224772644, 0.752508473438706, 0.7509983556495184, 0.751904426323031, 0.7514010537266351, 0.7524077989194268, 0.753716567670056, 0.7523071244001477, 0.7518373099768448, 0.7528440551696366, 0.7509983556495184, 0.7521057753615893, 0.7520386590154032, 0.7500251686298198, 0.7513674955535421, 0.7509312393033323, 0.7518037518037518]
# Test Accs: [0.7392753533732486, 0.7381437359833755, 0.7412093903668497, 0.7402629467316832, 0.7386169578009588, 0.7380408616752052, 0.7397280003291978, 0.7401189227002448, 0.7424233072032591, 0.7397280003291978, 0.7378351130588647, 0.7400160483920746, 0.740921342303973, 0.7385758080776906, 0.7411682406435817, 0.7389667304487377, 0.7396457008826616, 0.7384935086311545, 0.7396251260210275, 0.7379997119519371]
# Average val accuracy: 0.751870868149938 ± 0.0008415008835817228
# Average test accuracy: 0.7395397403452462 ± 0.0012162384423867229
...@@ -116,7 +116,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
    # training loop
    total_time = 0
    best_val_acc, final_test_acc, best_val_loss = 0, 0, float("inf")

    accs, train_accs, val_accs, test_accs = [], [], [], []
    losses, train_losses, val_losses, test_losses = [], [], [], []
...@@ -138,18 +138,17 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
        toc = time.time()
        total_time += toc - tic

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_val_acc = val_acc
            final_test_acc = test_acc

        if epoch % args.log_every == 0:
            print(
                f"Run: {n_running}/{args.n_runs}, Epoch: {epoch}/{args.n_epochs}, Average epoch time: {total_time / epoch:.2f}\n"
                f"Loss: {loss.item():.4f}, Acc: {acc:.4f}\n"
                f"Train/Val/Test loss: {train_loss:.4f}/{val_loss:.4f}/{test_loss:.4f}\n"
                f"Train/Val/Test/Best val/Final test acc: {train_acc:.4f}/{val_acc:.4f}/{test_acc:.4f}/{best_val_acc:.4f}/{final_test_acc:.4f}"
            )

        for l, e in zip(
...@@ -159,7 +158,8 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
            l.append(e)

    print("*" * 50)
    print(f"Best val acc: {best_val_acc}, Final test acc: {final_test_acc}")
    print("*" * 50)

    if args.plot_curves:
        fig = plt.figure(figsize=(24, 24))
...@@ -197,7 +197,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
        plt.tight_layout()
        plt.savefig(f"gcn_loss_{n_running}.png")

    return best_val_acc, final_test_acc


def count_parameters(args):
...@@ -211,19 +211,19 @@ def main():
    argparser = argparse.ArgumentParser("GCN on OGBN-Arxiv", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    argparser.add_argument("--cpu", action="store_true", help="CPU mode. This option overrides --gpu.")
    argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID.")
    argparser.add_argument("--n-runs", type=int, default=10, help="running times")
    argparser.add_argument("--n-epochs", type=int, default=1000, help="number of epochs")
    argparser.add_argument(
        "--use-labels", action="store_true", help="Use labels in the training set as input features."
    )
    argparser.add_argument("--use-linear", action="store_true", help="Use linear layer.")
    argparser.add_argument("--lr", type=float, default=0.005, help="learning rate")
    argparser.add_argument("--n-layers", type=int, default=3, help="number of layers")
    argparser.add_argument("--n-hidden", type=int, default=256, help="number of hidden units")
    argparser.add_argument("--dropout", type=float, default=0.5, help="dropout rate")
    argparser.add_argument("--wd", type=float, default=0, help="weight decay")
    argparser.add_argument("--log-every", type=int, default=20, help="log every LOG_EVERY epochs")
    argparser.add_argument("--plot-curves", action="store_true", help="plot learning curves")
    args = argparser.parse_args()

    if args.cpu:
...@@ -250,7 +250,7 @@ def main():
    in_feats = graph.ndata["feat"].shape[1]
    n_classes = (labels.max() + 1).item()
    graph.create_formats_()

    train_idx = train_idx.to(device)
    val_idx = val_idx.to(device)
...
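The gcn.py change above also switches the reported number from the best test accuracy seen during training to the test accuracy at the epoch with the lowest validation loss, which is the honest model-selection protocol. A tiny sketch of that bookkeeping, with made-up per-epoch metrics:

```python
# Made-up per-epoch metrics, for illustration only.
val_losses = [0.9, 0.7, 0.8, 0.6, 0.65]
test_accs = [0.70, 0.72, 0.73, 0.71, 0.74]

best_val_loss, final_test_acc = float("inf"), 0
for val_loss, test_acc in zip(val_losses, test_accs):
    if val_loss < best_val_loss:     # select the checkpoint on validation loss only
        best_val_loss = val_loss
        final_test_acc = test_acc    # record the test accuracy at that epoch

print(final_test_acc)  # 0.71, not the 0.74 peak reached at a worse-validation epoch
```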
...@@ -2,24 +2,43 @@ import dgl.nn.pytorch as dglnn ...@@ -2,24 +2,43 @@ import dgl.nn.pytorch as dglnn
import torch import torch
import torch.nn as nn import torch.nn as nn
from dgl import function as fn from dgl import function as fn
from dgl._ffi.base import DGLError from dgl.ops import edge_softmax
from dgl.nn.pytorch.utils import Identity
from dgl.nn.functional import edge_softmax
from dgl.utils import expand_as_pair from dgl.utils import expand_as_pair
class Bias(nn.Module): class ElementWiseLinear(nn.Module):
def __init__(self, size): def __init__(self, size, weight=True, bias=True, inplace=False):
super().__init__() super().__init__()
self.bias = nn.Parameter(torch.Tensor(size)) if weight:
self.weight = nn.Parameter(torch.Tensor(size))
else:
self.weight = None
if bias:
self.bias = nn.Parameter(torch.Tensor(size))
else:
self.bias = None
self.inplace = inplace
self.reset_parameters() self.reset_parameters()
def reset_parameters(self): def reset_parameters(self):
nn.init.zeros_(self.bias) if self.weight is not None:
nn.init.ones_(self.weight)
if self.bias is not None:
nn.init.zeros_(self.bias)
def forward(self, x): def forward(self, x):
return x + self.bias if self.inplace:
if self.weight is not None:
x.mul_(self.weight)
if self.bias is not None:
x.add_(self.bias)
else:
if self.weight is not None:
x = x * self.weight
if self.bias is not None:
x = x + self.bias
return x
class GCN(nn.Module): class GCN(nn.Module):
...@@ -33,7 +52,7 @@ class GCN(nn.Module): ...@@ -33,7 +52,7 @@ class GCN(nn.Module):
self.convs = nn.ModuleList() self.convs = nn.ModuleList()
if use_linear: if use_linear:
self.linear = nn.ModuleList() self.linear = nn.ModuleList()
self.bns = nn.ModuleList() self.norms = nn.ModuleList()
for i in range(n_layers): for i in range(n_layers):
in_hidden = n_hidden if i > 0 else in_feats in_hidden = n_hidden if i > 0 else in_feats
...@@ -44,15 +63,15 @@ class GCN(nn.Module): ...@@ -44,15 +63,15 @@ class GCN(nn.Module):
if use_linear: if use_linear:
self.linear.append(nn.Linear(in_hidden, out_hidden, bias=False)) self.linear.append(nn.Linear(in_hidden, out_hidden, bias=False))
if i < n_layers - 1: if i < n_layers - 1:
self.bns.append(nn.BatchNorm1d(out_hidden)) self.norms.append(nn.BatchNorm1d(out_hidden))
self.dropout0 = nn.Dropout(min(0.1, dropout)) self.input_drop = nn.Dropout(min(0.1, dropout))
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
self.activation = activation self.activation = activation
def forward(self, graph, feat): def forward(self, graph, feat):
h = feat h = feat
h = self.dropout0(h) h = self.input_drop(h)
for i in range(self.n_layers): for i in range(self.n_layers):
conv = self.convs[i](graph, h) conv = self.convs[i](graph, h)
...@@ -64,7 +83,7 @@ class GCN(nn.Module): ...@@ -64,7 +83,7 @@ class GCN(nn.Module):
h = conv h = conv
if i < self.n_layers - 1: if i < self.n_layers - 1:
h = self.bns[i](h) h = self.norms[i](h)
h = self.activation(h) h = self.activation(h)
h = self.dropout(h) h = self.dropout(h)
...@@ -79,35 +98,36 @@ class GATConv(nn.Module): ...@@ -79,35 +98,36 @@ class GATConv(nn.Module):
num_heads=1, num_heads=1,
feat_drop=0.0, feat_drop=0.0,
attn_drop=0.0, attn_drop=0.0,
edge_drop=0.0,
negative_slope=0.2, negative_slope=0.2,
use_attn_dst=True,
residual=False, residual=False,
activation=None, activation=None,
allow_zero_in_degree=False, allow_zero_in_degree=False,
norm="none", use_symmetric_norm=False,
): ):
super(GATConv, self).__init__() super(GATConv, self).__init__()
if norm not in ("none", "both"):
raise DGLError('Invalid norm value. Must be either "none", "both".' ' But got "{}".'.format(norm))
self._num_heads = num_heads self._num_heads = num_heads
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats) self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._out_feats = out_feats self._out_feats = out_feats
self._allow_zero_in_degree = allow_zero_in_degree self._allow_zero_in_degree = allow_zero_in_degree
self._norm = norm self._use_symmetric_norm = use_symmetric_norm
if isinstance(in_feats, tuple): if isinstance(in_feats, tuple):
self.fc_src = nn.Linear(self._in_src_feats, out_feats * num_heads, bias=False) self.fc_src = nn.Linear(self._in_src_feats, out_feats * num_heads, bias=False)
self.fc_dst = nn.Linear(self._in_dst_feats, out_feats * num_heads, bias=False) self.fc_dst = nn.Linear(self._in_dst_feats, out_feats * num_heads, bias=False)
else: else:
self.fc = nn.Linear(self._in_src_feats, out_feats * num_heads, bias=False) self.fc = nn.Linear(self._in_src_feats, out_feats * num_heads, bias=False)
self.attn_l = nn.Parameter(torch.FloatTensor(size=(1, num_heads, out_feats))) self.attn_l = nn.Parameter(torch.FloatTensor(size=(1, num_heads, out_feats)))
self.attn_r = nn.Parameter(torch.FloatTensor(size=(1, num_heads, out_feats))) if use_attn_dst:
self.attn_r = nn.Parameter(torch.FloatTensor(size=(1, num_heads, out_feats)))
else:
self.register_buffer("attn_r", None)
self.feat_drop = nn.Dropout(feat_drop) self.feat_drop = nn.Dropout(feat_drop)
self.attn_drop = nn.Dropout(attn_drop) self.attn_drop = nn.Dropout(attn_drop)
self.edge_drop = edge_drop
self.leaky_relu = nn.LeakyReLU(negative_slope) self.leaky_relu = nn.LeakyReLU(negative_slope)
if residual: if residual:
if self._in_dst_feats != out_feats: self.res_fc = nn.Linear(self._in_dst_feats, num_heads * out_feats, bias=False)
self.res_fc = nn.Linear(self._in_dst_feats, num_heads * out_feats, bias=False)
else:
self.res_fc = Identity()
else: else:
self.register_buffer("res_fc", None) self.register_buffer("res_fc", None)
self.reset_parameters() self.reset_parameters()
...@@ -121,7 +141,8 @@ class GATConv(nn.Module): ...@@ -121,7 +141,8 @@ class GATConv(nn.Module):
nn.init.xavier_normal_(self.fc_src.weight, gain=gain) nn.init.xavier_normal_(self.fc_src.weight, gain=gain)
nn.init.xavier_normal_(self.fc_dst.weight, gain=gain) nn.init.xavier_normal_(self.fc_dst.weight, gain=gain)
nn.init.xavier_normal_(self.attn_l, gain=gain) nn.init.xavier_normal_(self.attn_l, gain=gain)
nn.init.xavier_normal_(self.attn_r, gain=gain) if isinstance(self.attn_r, nn.Parameter):
nn.init.xavier_normal_(self.attn_r, gain=gain)
if isinstance(self.res_fc, nn.Linear): if isinstance(self.res_fc, nn.Linear):
nn.init.xavier_normal_(self.res_fc.weight, gain=gain) nn.init.xavier_normal_(self.res_fc.weight, gain=gain)
...@@ -143,13 +164,17 @@ class GATConv(nn.Module): ...@@ -143,13 +164,17 @@ class GATConv(nn.Module):
feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats) feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats) feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)
else: else:
h_src = h_dst = self.feat_drop(feat) h_src = self.feat_drop(feat)
feat_src, feat_dst = h_src, h_dst feat_src = h_src
feat_src = feat_dst = self.fc(h_src).view(-1, self._num_heads, self._out_feats) feat_src = self.fc(h_src).view(-1, self._num_heads, self._out_feats)
if graph.is_block: if graph.is_block:
h_dst = h_src[: graph.number_of_dst_nodes()]
feat_dst = feat_src[: graph.number_of_dst_nodes()] feat_dst = feat_src[: graph.number_of_dst_nodes()]
else:
h_dst = h_src
feat_dst = feat_src
if self._norm == "both": if self._use_symmetric_norm:
degs = graph.out_degrees().float().clamp(min=1) degs = graph.out_degrees().float().clamp(min=1)
norm = torch.pow(degs, -0.5) norm = torch.pow(degs, -0.5)
shp = norm.shape + (1,) * (feat_src.dim() - 1) shp = norm.shape + (1,) * (feat_src.dim() - 1)
...@@ -167,19 +192,30 @@ class GATConv(nn.Module): ...@@ -167,19 +192,30 @@ class GATConv(nn.Module):
# addition could be optimized with DGL's built-in function u_add_v, # addition could be optimized with DGL's built-in function u_add_v,
# which further speeds up computation and saves memory footprint. # which further speeds up computation and saves memory footprint.
el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1) el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
graph.srcdata.update({"ft": feat_src, "el": el}) graph.srcdata.update({"ft": feat_src, "el": el})
graph.dstdata.update({"er": er})
# compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively. # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
graph.apply_edges(fn.u_add_v("el", "er", "e")) if self.attn_r is not None:
er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
graph.dstdata.update({"er": er})
graph.apply_edges(fn.u_add_v("el", "er", "e"))
else:
graph.apply_edges(fn.copy_u("el", "e"))
e = self.leaky_relu(graph.edata.pop("e")) e = self.leaky_relu(graph.edata.pop("e"))
# compute softmax
graph.edata["a"] = self.attn_drop(edge_softmax(graph, e)) if self.training and self.edge_drop > 0:
perm = torch.randperm(graph.number_of_edges(), device=e.device)
bound = int(graph.number_of_edges() * self.edge_drop)
eids = perm[bound:]
graph.edata["a"] = torch.zeros_like(e)
graph.edata["a"][eids] = self.attn_drop(edge_softmax(graph, e[eids], eids=eids))
else:
graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))
# message passing
graph.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft"))
rst = graph.dstdata["ft"]
if self._use_symmetric_norm:
degs = graph.in_degrees().float().clamp(min=1)
norm = torch.pow(degs, 0.5)
shp = norm.shape + (1,) * (feat_dst.dim() - 1)
@@ -190,15 +226,29 @@ class GATConv(nn.Module):
if self.res_fc is not None:
resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats)
rst = rst + resval
# activation
if self._activation is not None:
rst = self._activation(rst)
return rst
class GAT(nn.Module):
def __init__(
self,
in_feats,
n_classes,
n_hidden,
n_layers,
n_heads,
activation,
dropout=0.0,
input_drop=0.0,
attn_drop=0.0,
edge_drop=0.0,
use_attn_dst=True,
use_symmetric_norm=False,
):
super().__init__()
self.in_feats = in_feats
@@ -208,42 +258,49 @@ class GAT(nn.Module):
self.num_heads = n_heads
self.convs = nn.ModuleList()
self.norms = nn.ModuleList()
for i in range(n_layers):
in_hidden = n_heads * n_hidden if i > 0 else in_feats
out_hidden = n_hidden if i < n_layers - 1 else n_classes
num_heads = n_heads if i < n_layers - 1 else 1
out_channels = n_heads
self.convs.append(
GATConv(
in_hidden,
out_hidden,
num_heads=num_heads,
attn_drop=attn_drop,
edge_drop=edge_drop,
use_attn_dst=use_attn_dst,
use_symmetric_norm=use_symmetric_norm,
residual=True,
)
)
if i < n_layers - 1:
self.norms.append(nn.BatchNorm1d(out_channels * out_hidden))
self.bias_last = ElementWiseLinear(n_classes, weight=False, bias=True, inplace=True)
self.input_drop = nn.Dropout(input_drop)
self.dropout = nn.Dropout(dropout)
self.activation = activation
def forward(self, graph, feat):
h = feat
h = self.input_drop(h)
for i in range(self.n_layers):
conv = self.convs[i](graph, h)
h = conv
if i < self.n_layers - 1:
h = h.flatten(1)
h = self.norms[i](h)
h = self.activation(h, inplace=True)
h = self.dropout(h)
h = h.mean(1)
...
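Between hidden layers the per-head outputs are flattened into a single feature dimension, while the output layer (reduced to one head in this commit) is averaged over the head dimension. A shape-only sketch of those two reductions, in plain Python with no torch dependency:

```python
n, heads, hidden = 4, 3, 8

def flatten_heads(shape):
    # (N, H, F) -> (N, H * F), what h.flatten(1) does between hidden layers
    n_, h_, f_ = shape
    return (n_, h_ * f_)

def mean_heads(shape):
    # (N, H, F) -> (N, F), what h.mean(1) does at the output layer
    n_, _, f_ = shape
    return (n_, f_)

print(flatten_heads((n, heads, hidden)))  # (4, 24)
print(mean_heads((n, 1, hidden)))         # (4, 8)
```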
# DGL examples for ogbn-products
## Sample-based GAT
Requires DGL 0.4.3post2 or later versions.
Run `main.py` and you should directly see the result.
Accuracy over 5 runs: 0.7863197 ± 0.00072568655
## GAT (another implementation)
Requires DGL 0.5 or later versions.
To get the score of `GAT`, run the following command; you should see the result directly.
```bash
python3 gat.py
```
Alternatively, to speed up training, run with `--estimation-mode` enabled.
This option estimates the test score on a subset during training and performs a complete evaluation once training is over.
```bash
python3 gat.py --estimation-mode
```
## Results
Here are the results over 10 runs.
| Method | Validation Accuracy | Test Accuracy | #Parameters |
|:-------------:|:-------------------:|:---------------:|:-----------:|
| GAT (main.py) | N/A | 0.7863 ± 0.0007 | N/A |
| GAT (gat.py) | 0.9327 ± 0.0003 | 0.8126 ± 0.0018 | 1,065,127 |
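The summary rows above report the mean and (population) standard deviation of the per-run accuracies printed by the scripts. A minimal sketch of that aggregation; the score values below are illustrative, not the actual runs:

```python
import math

def mean_std(xs):
    # population statistics, matching numpy's default np.std (ddof=0)
    m = sum(xs) / len(xs)
    return m, math.sqrt(sum((x - m) ** 2 for x in xs) / len(xs))

scores = [0.8126, 0.8131, 0.8121]  # illustrative per-run test accuracies
m, s = mean_std(scores)
print(f"{m:.4f} ± {s:.4f}")
```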
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import argparse
import math
import random
import time
from collections import OrderedDict
import dgl
import dgl.function as fn
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from dgl.dataloading import MultiLayerFullNeighborSampler, MultiLayerNeighborSampler
from dgl.dataloading.pytorch import NodeDataLoader
from matplotlib.ticker import AutoMinorLocator, MultipleLocator
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from torch import nn
from tqdm import tqdm
from models import GAT
from utils import BatchSampler, DataLoaderWrapper
epsilon = 1 - math.log(2)
device = None
dataset = "ogbn-products"
n_node_feats, n_edge_feats, n_classes = 0, 0, 0
def seed(seed=0):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
dgl.random.seed(seed)
def load_data(dataset):
data = DglNodePropPredDataset(name=dataset)
evaluator = Evaluator(name=dataset)
splitted_idx = data.get_idx_split()
train_idx, val_idx, test_idx = splitted_idx["train"], splitted_idx["valid"], splitted_idx["test"]
graph, labels = data[0]
graph.ndata["labels"] = labels
return graph, labels, train_idx, val_idx, test_idx, evaluator
def preprocess(graph, labels, train_idx):
global n_node_feats, n_classes
n_node_feats = graph.ndata["feat"].shape[-1]
n_classes = (labels.max() + 1).item()
# graph = graph.remove_self_loop().add_self_loop()
graph.ndata["train_labels_onehot"] = torch.zeros(graph.number_of_nodes(), n_classes)
graph.ndata["train_labels_onehot"][train_idx, labels[train_idx, 0]] = 1
graph.ndata["is_train"] = torch.zeros(graph.number_of_nodes(), dtype=torch.bool)
graph.ndata["is_train"][train_idx] = 1
graph.create_formats_()
return graph, labels
def gen_model(args):
if args.use_labels:
n_node_feats_ = n_node_feats + n_classes
else:
n_node_feats_ = n_node_feats
model = GAT(
n_node_feats_,
n_edge_feats,
n_classes,
n_layers=args.n_layers,
n_heads=args.n_heads,
n_hidden=args.n_hidden,
edge_emb=0,
activation=F.relu,
dropout=args.dropout,
input_drop=args.input_drop,
attn_drop=args.attn_dropout,
edge_drop=args.edge_drop,
use_attn_dst=not args.no_attn_dst,
allow_zero_in_degree=True,
residual=False,
)
return model
def custom_loss_function(x, labels):
y = F.cross_entropy(x, labels[:, 0], reduction="none")
y = torch.log(epsilon + y) - math.log(epsilon)
return torch.mean(y)
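`custom_loss_function` applies a logarithmic rescaling to the per-example cross-entropy: with ε = 1 − log 2, the transform y = log(ε + CE) − log ε is exactly zero when CE is zero and grows only logarithmically for large CE, which damps the influence of hard examples. The scalar transform in isolation:

```python
import math

EPS = 1 - math.log(2)  # same epsilon as in the training script

def soft_loss(ce):
    # log-rescaled cross-entropy: zero at ce == 0, and compressing
    # large cross-entropy values so hard examples dominate less
    return math.log(EPS + ce) - math.log(EPS)

print(soft_loss(0.0))  # 0.0
```

The transform is monotone, so the relative ordering of example losses is preserved even though large losses are compressed.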
def add_soft_labels(graph, soft_labels):
feat = graph.srcdata["feat"]
graph.srcdata["feat"] = torch.cat([feat, soft_labels], dim=-1)
def update_hard_labels(graph, idx=None):
if idx is None:
idx = torch.arange(graph.srcdata["is_train"].shape[0])[graph.srcdata["is_train"]]
graph.srcdata["feat"][idx, -n_classes:] = graph.srcdata["train_labels_onehot"][idx]
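Label reuse, as implemented by `add_soft_labels` and `update_hard_labels`, concatenates predicted soft labels to the node features and then overwrites the rows of nodes with known training labels with their one-hot hard labels. The overwrite logic can be sketched in plain Python (toy sizes, no DGL):

```python
# Toy label-reuse sketch: 5 nodes, 3 classes; nodes 0 and 2 have known labels.
n_classes = 3
soft = [[0.2, 0.5, 0.3] for _ in range(5)]   # soft labels: model's softmax output
train_nodes = {0: 1, 2: 0}                   # node id -> known hard label

feat_labels = [row[:] for row in soft]
for node, label in train_nodes.items():
    # the known one-hot label overwrites the soft prediction for that node
    feat_labels[node] = [1.0 if c == label else 0.0 for c in range(n_classes)]

print(feat_labels[0])  # [0.0, 1.0, 0.0] -> hard label fed back as a feature
print(feat_labels[1])  # [0.2, 0.5, 0.3] -> soft prediction kept
```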
def train(args, model, dataloader, labels, train_idx, criterion, optimizer, evaluator):
model.train()
loss_sum, total = 0, 0
preds = torch.zeros(labels.shape[0], n_classes)
for it in range(args.n_label_iters + 1):
preds_old = preds.clone()
for input_nodes, output_nodes, subgraphs in dataloader:
subgraphs = [b.to(device) for b in subgraphs]
new_train_idx = torch.arange(len(output_nodes))
if args.use_labels:
mask = torch.rand(new_train_idx.shape) < args.mask_rate
train_labels_idx = torch.cat([new_train_idx[~mask], torch.arange(len(output_nodes), len(input_nodes))])
train_pred_idx = new_train_idx[mask]
add_soft_labels(subgraphs[0], F.softmax(preds_old[input_nodes].to(device), dim=-1))
update_hard_labels(subgraphs[0], train_labels_idx)
else:
train_pred_idx = new_train_idx
pred = model(subgraphs)
preds[output_nodes] = pred.cpu().detach()
# NOTE: This is not a complete implementation of label reuse, since it is too expensive
# to predict the nodes in validation and test set during training time.
if it == args.n_label_iters:
loss = criterion(pred[train_pred_idx], subgraphs[-1].dstdata["labels"][train_pred_idx])
optimizer.zero_grad()
loss.backward()
optimizer.step()
count = len(train_pred_idx)
loss_sum += loss.item() * count
total += count
torch.cuda.empty_cache()
return (
evaluator(preds[train_idx], labels[train_idx]),
loss_sum / total,
)
@torch.no_grad()
def evaluate(args, model, dataloader, labels, train_idx, val_idx, test_idx, criterion, evaluator):
model.eval()
# Due to limited memory capacity, we average the logits over 'eval_times' evaluation passes.
eval_times = 1
preds_avg = torch.zeros(labels.shape[0], n_classes)
for _ in range(eval_times):
preds = torch.zeros(labels.shape[0], n_classes)
for _it in range(args.n_label_iters + 1):
preds_old = preds.clone()
for input_nodes, output_nodes, subgraphs in dataloader:
subgraphs = [b.to(device) for b in subgraphs]
if args.use_labels:
add_soft_labels(subgraphs[0], F.softmax(preds_old[input_nodes].to(device), dim=-1))
update_hard_labels(subgraphs[0])
pred = model(subgraphs, inference=True)
preds[output_nodes] = pred.cpu()
torch.cuda.empty_cache()
preds_avg += preds
preds_avg = preds_avg.to(device)
preds_avg /= eval_times
train_loss = criterion(preds_avg[train_idx], labels[train_idx]).item()
val_loss = criterion(preds_avg[val_idx], labels[val_idx]).item()
test_loss = criterion(preds_avg[test_idx], labels[test_idx]).item()
return (
evaluator(preds_avg[train_idx], labels[train_idx]),
evaluator(preds_avg[val_idx], labels[val_idx]),
evaluator(preds_avg[test_idx], labels[test_idx]),
train_loss,
val_loss,
test_loss,
)
def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running):
evaluator_wrapper = lambda pred, labels: evaluator.eval(
{"y_pred": pred.argmax(dim=-1, keepdim=True), "y_true": labels}
)["acc"]
criterion = custom_loss_function
n_train_samples = train_idx.shape[0]
train_batch_size = (n_train_samples + 29) // 30
train_sampler = MultiLayerNeighborSampler([10 for _ in range(args.n_layers)])
train_dataloader = DataLoaderWrapper(
NodeDataLoader(
graph.cpu(),
train_idx.cpu(),
train_sampler,
batch_sampler=BatchSampler(len(train_idx), batch_size=train_batch_size, shuffle=True),
num_workers=4,
)
)
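The training batch size above, `(n_train_samples + 29) // 30`, is ceiling division: the smallest batch size that covers the training split in 30 mini-batches. A quick check of that identity, assuming the standard ogbn-products training split of 196,615 nodes:

```python
def ceil_div(n, d):
    # smallest integer >= n / d
    return (n + d - 1) // d

n_train = 196_615  # ogbn-products training-split size (assumed here)
batch_size = ceil_div(n_train, 30)
# this batch size walks the whole split in exactly 30 mini-batches
print(batch_size, ceil_div(n_train, batch_size))  # 6554 30
```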
eval_batch_size = 32768
eval_sampler = MultiLayerNeighborSampler([15 for _ in range(args.n_layers)])
if args.estimation_mode:
test_idx_during_training = test_idx[torch.arange(start=0, end=len(test_idx), step=45)]
else:
test_idx_during_training = test_idx
eval_idx = torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx_during_training.cpu()])
eval_dataloader = DataLoaderWrapper(
NodeDataLoader(
graph.cpu(),
eval_idx,
eval_sampler,
batch_sampler=BatchSampler(len(eval_idx), batch_size=eval_batch_size, shuffle=False),
num_workers=4,
)
)
model = gen_model(args).to(device)
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="max", factor=0.7, patience=20, verbose=True, min_lr=1e-4
)
best_model_state_dict = None
total_time = 0
val_score, best_val_score, final_test_score = 0, 0, 0
scores, train_scores, val_scores, test_scores = [], [], [], []
losses, train_losses, val_losses, test_losses = [], [], [], []
for epoch in range(1, args.n_epochs + 1):
tic = time.time()
score, loss = train(args, model, train_dataloader, labels, train_idx, criterion, optimizer, evaluator_wrapper)
toc = time.time()
total_time += toc - tic
if epoch == args.n_epochs or epoch % args.eval_every == 0 or epoch % args.log_every == 0:
train_score, val_score, test_score, train_loss, val_loss, test_loss = evaluate(
args,
model,
eval_dataloader,
labels,
train_idx,
val_idx,
test_idx_during_training,
criterion,
evaluator_wrapper,
)
if val_score > best_val_score:
best_val_score = val_score
final_test_score = test_score
if args.estimation_mode:
best_model_state_dict = {k: v.to("cpu") for k, v in model.state_dict().items()}
if epoch == args.n_epochs or epoch % args.log_every == 0:
print(
f"Run: {n_running}/{args.n_runs}, Epoch: {epoch}/{args.n_epochs}, Average epoch time: {total_time / epoch:.2f}\n"
f"Loss: {loss:.4f}, Score: {score:.4f}\n"
f"Train/Val/Test loss: {train_loss:.4f}/{val_loss:.4f}/{test_loss:.4f}\n"
f"Train/Val/Test/Best val/Final test score: {train_score:.4f}/{val_score:.4f}/{test_score:.4f}/{best_val_score:.4f}/{final_test_score:.4f}"
)
for l, e in zip(
[scores, train_scores, val_scores, test_scores, losses, train_losses, val_losses, test_losses],
[score, train_score, val_score, test_score, loss, train_loss, val_loss, test_loss],
):
l.append(e)
lr_scheduler.step(val_score)
if args.estimation_mode:
model.load_state_dict(best_model_state_dict)
eval_dataloader = DataLoaderWrapper(
NodeDataLoader(
graph.cpu(),
test_idx.cpu(),
eval_sampler,
batch_sampler=BatchSampler(len(test_idx), batch_size=eval_batch_size, shuffle=False),
num_workers=4,
)
)
final_test_score = evaluate(
args, model, eval_dataloader, labels, train_idx, val_idx, test_idx, criterion, evaluator_wrapper
)[2]
print("*" * 50)
print(f"Best val score: {best_val_score}, Final test score: {final_test_score}")
print("*" * 50)
if args.plot_curves:
fig = plt.figure(figsize=(24, 24))
ax = fig.gca()
ax.set_xticks(np.arange(0, args.n_epochs, 100))
ax.set_yticks(np.linspace(0, 1.0, 101))
ax.tick_params(labeltop=True, labelright=True)
for y, label in zip([train_scores, val_scores, test_scores], ["train score", "val score", "test score"]):
plt.plot(range(1, args.n_epochs + 1, args.log_every), y, label=label, linewidth=1)
ax.xaxis.set_major_locator(MultipleLocator(10))
ax.xaxis.set_minor_locator(AutoMinorLocator(1))
ax.yaxis.set_major_locator(MultipleLocator(0.01))
ax.yaxis.set_minor_locator(AutoMinorLocator(2))
plt.grid(which="major", color="red", linestyle="dotted")
plt.grid(which="minor", color="orange", linestyle="dotted")
plt.legend()
plt.tight_layout()
plt.savefig(f"gat_score_{n_running}.png")
fig = plt.figure(figsize=(24, 24))
ax = fig.gca()
ax.set_xticks(np.arange(0, args.n_epochs, 100))
ax.tick_params(labeltop=True, labelright=True)
for y, label in zip(
[losses, train_losses, val_losses, test_losses], ["loss", "train loss", "val loss", "test loss"]
):
plt.plot(range(1, args.n_epochs + 1, args.log_every), y, label=label, linewidth=1)
ax.xaxis.set_major_locator(MultipleLocator(10))
ax.xaxis.set_minor_locator(AutoMinorLocator(1))
ax.yaxis.set_major_locator(MultipleLocator(0.1))
ax.yaxis.set_minor_locator(AutoMinorLocator(5))
plt.grid(which="major", color="red", linestyle="dotted")
plt.grid(which="minor", color="orange", linestyle="dotted")
plt.legend()
plt.tight_layout()
plt.savefig(f"gat_loss_{n_running}.png")
return best_val_score, final_test_score
def count_parameters(args):
model = gen_model(args)
return sum([np.prod(p.size()) for p in model.parameters() if p.requires_grad])
def main():
global device
argparser = argparse.ArgumentParser(
"GAT implementation on ogbn-products", formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
argparser.add_argument("--cpu", action="store_true", help="CPU mode. This option overrides '--gpu'.")
argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID")
argparser.add_argument("--seed", type=int, default=0, help="seed")
argparser.add_argument("--n-runs", type=int, default=10, help="running times")
argparser.add_argument("--n-epochs", type=int, default=250, help="number of epochs")
argparser.add_argument(
"--use-labels", action="store_true", help="Use labels in the training set as input features."
)
argparser.add_argument("--n-label-iters", type=int, default=0, help="number of label iterations")
argparser.add_argument("--no-attn-dst", action="store_true", help="Don't use attn_dst.")
argparser.add_argument("--mask-rate", type=float, default=0.5, help="mask rate")
argparser.add_argument("--n-heads", type=int, default=4, help="number of heads")
argparser.add_argument("--lr", type=float, default=0.01, help="learning rate")
argparser.add_argument("--n-layers", type=int, default=3, help="number of layers")
argparser.add_argument("--n-hidden", type=int, default=120, help="number of hidden units")
argparser.add_argument("--dropout", type=float, default=0.5, help="dropout rate")
argparser.add_argument("--input-drop", type=float, default=0.1, help="input drop rate")
argparser.add_argument("--attn-dropout", type=float, default=0.0, help="attention drop rate")
argparser.add_argument("--edge-drop", type=float, default=0.1, help="edge drop rate")
argparser.add_argument("--wd", type=float, default=0, help="weight decay")
argparser.add_argument("--eval-every", type=int, default=2, help="log every EVAL_EVERY epochs")
argparser.add_argument(
"--estimation-mode", action="store_true", help="Estimate the score of test set for speed during training."
)
argparser.add_argument("--log-every", type=int, default=2, help="log every LOG_EVERY epochs")
argparser.add_argument("--plot-curves", action="store_true", help="plot learning curves")
args = argparser.parse_args()
if args.cpu:
device = torch.device("cpu")
else:
device = torch.device("cuda:%d" % args.gpu)
# load data & preprocess
graph, labels, train_idx, val_idx, test_idx, evaluator = load_data(dataset)
graph, labels = preprocess(graph, labels, train_idx)
labels, train_idx, val_idx, test_idx = map(lambda x: x.to(device), (labels, train_idx, val_idx, test_idx))
# run
val_scores, test_scores = [], []
for i in range(1, args.n_runs + 1):
seed(args.seed + i)
val_score, test_score = run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, i)
val_scores.append(val_score)
test_scores.append(test_score)
print(args)
print(f"Ran {args.n_runs} times")
print("Val scores:", val_scores)
print("Test scores:", test_scores)
print(f"Average val score: {np.mean(val_scores)} ± {np.std(val_scores)}")
print(f"Average test score: {np.mean(test_scores)} ± {np.std(test_scores)}")
print(f"Number of params: {count_parameters(args)}")
if args.estimation_mode:
print(
"WARNING: Estimation mode is enabled. The final test score is accurate, but test scores reported during training are estimated on a subset of the test set."
)
if __name__ == "__main__":
main()
# Namespace(attn_dropout=0.0, cpu=False, dropout=0.5, edge_drop=0.1, estimation_mode=True, eval_every=2, gpu=1, input_drop=0.1, log_every=2, lr=0.01, mask_rate=0.5, n_epochs=250, n_heads=4, n_hidden=120, n_label_iters=0, n_layers=3, n_runs=10, no_attn_dst=False, plot_curves=True, seed=0, use_labels=False, wd=0)
# Runned 10 times
# Val scores: [0.9326348447473489, 0.9330163008926073, 0.9327619967957684, 0.932355110240826, 0.9330163008926073, 0.9327365663860845, 0.9329145792538718, 0.9322788190117742, 0.9321516669633548, 0.9329908704829235]
# Test scores: [0.8147550191112792, 0.8115680737936217, 0.8128332725586069, 0.8134062268564646, 0.8118784993477448, 0.8145462613150566, 0.8151228304665284, 0.8115274066904614, 0.8108545920615103, 0.8094583548530088]
# Average val score: 0.9326857055667167 ± 0.00030580001557474636
# Average test score: 0.8125950537054282 ± 0.001765025824381352
# Number of params: 1065127
# Namespace(attn_dropout=0.0, cpu=False, dropout=0.5, edge_drop=0.1, estimation_mode=True, eval_every=2, gpu=0, input_drop=0.1, log_every=2, lr=0.01, mask_rate=0.5, n_epochs=250, n_heads=4, n_hidden=120, n_label_iters=0, n_layers=3, n_runs=5, no_attn_dst=True, plot_curves=True, seed=0, use_labels=False, wd=0)
# Runned 10 times
# Val scores: [0.9332451745797625, 0.9330417313022913, 0.9328128576151362, 0.9323296798311421, 0.9324568318795616, 0.9327874272054523, 0.9327619967957684, 0.9328128576151362, 0.9322025277827226, 0.9329400096635557]
# Test scores: [0.8103399272781824, 0.8115870517750965, 0.8107294277551171, 0.8115771109276573, 0.8130244079434601, 0.8094628734200265, 0.8105681149125815, 0.809217063374258, 0.8108085026779287, 0.8151549122923549]
# Average val score: 0.932739109427053 ± 0.0003061065079170266
# Average test score: 0.8112469392356664 ± 0.0016644261188834386
# Number of params: 1060887
# update time: 2020.11.02 17:33
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import function as fn
from dgl.ops import edge_softmax
from dgl.utils import expand_as_pair
class GATConv(nn.Module):
def __init__(
self,
node_feats,
edge_feats,
out_feats,
n_heads=1,
attn_drop=0.0,
edge_drop=0.0,
negative_slope=0.2,
residual=True,
activation=None,
use_attn_dst=True,
allow_zero_in_degree=True,
use_symmetric_norm=False,
):
super(GATConv, self).__init__()
self._n_heads = n_heads
self._in_src_feats, self._in_dst_feats = expand_as_pair(node_feats)
self._out_feats = out_feats
self._allow_zero_in_degree = allow_zero_in_degree
self._use_symmetric_norm = use_symmetric_norm
# feat fc
self.src_fc = nn.Linear(self._in_src_feats, out_feats * n_heads, bias=False)
if residual:
self.dst_fc = nn.Linear(self._in_src_feats, out_feats * n_heads)
self.bias = None
else:
self.dst_fc = None
self.bias = nn.Parameter(torch.zeros(n_heads, out_feats))  # broadcasts over (N, n_heads, out_feats)
# attn fc
self.attn_src_fc = nn.Linear(self._in_src_feats, n_heads, bias=False)
if use_attn_dst:
self.attn_dst_fc = nn.Linear(self._in_src_feats, n_heads, bias=False)
else:
self.attn_dst_fc = None
if edge_feats > 0:
self.attn_edge_fc = nn.Linear(edge_feats, n_heads, bias=False)
else:
self.attn_edge_fc = None
self.attn_drop = nn.Dropout(attn_drop)
self.edge_drop = edge_drop
self.leaky_relu = nn.LeakyReLU(negative_slope, inplace=True)
self.activation = activation
self.reset_parameters()
def reset_parameters(self):
gain = nn.init.calculate_gain("relu")
nn.init.xavier_normal_(self.src_fc.weight, gain=gain)
if self.dst_fc is not None:
nn.init.xavier_normal_(self.dst_fc.weight, gain=gain)
nn.init.xavier_normal_(self.attn_src_fc.weight, gain=gain)
if self.attn_dst_fc is not None:
nn.init.xavier_normal_(self.attn_dst_fc.weight, gain=gain)
if self.attn_edge_fc is not None:
nn.init.xavier_normal_(self.attn_edge_fc.weight, gain=gain)
if self.bias is not None:
nn.init.zeros_(self.bias)
def set_allow_zero_in_degree(self, set_value):
self._allow_zero_in_degree = set_value
def forward(self, graph, feat_src, feat_edge=None):
with graph.local_scope():
if not self._allow_zero_in_degree:
if (graph.in_degrees() == 0).any():
assert False, "There are 0-in-degree nodes in the graph; set allow_zero_in_degree=True or add self-loops."
if graph.is_block:
feat_dst = feat_src[: graph.number_of_dst_nodes()]
else:
feat_dst = feat_src
if self._use_symmetric_norm:
degs = graph.out_degrees().float().clamp(min=1)
norm = torch.pow(degs, -0.5)
shp = norm.shape + (1,) * (feat_src.dim() - 1)
norm = torch.reshape(norm, shp)
feat_src = feat_src * norm
feat_src_fc = self.src_fc(feat_src).view(-1, self._n_heads, self._out_feats)
feat_dst_fc = self.dst_fc(feat_dst).view(-1, self._n_heads, self._out_feats)
attn_src = self.attn_src_fc(feat_src).view(-1, self._n_heads, 1)
# NOTE: GAT paper uses "first concatenation then linear projection"
# to compute attention scores, while ours is "first projection then
# addition", the two approaches are mathematically equivalent:
# We decompose the weight vector a mentioned in the paper into
# [a_l || a_r], then
# a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j
# Our implementation is much more efficient because we do not need to
# save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,
# addition could be optimized with DGL's built-in function u_add_v,
# which further speeds up computation and saves memory footprint.
graph.srcdata.update({"feat_src_fc": feat_src_fc, "attn_src": attn_src})
if self.attn_dst_fc is not None:
attn_dst = self.attn_dst_fc(feat_dst).view(-1, self._n_heads, 1)
graph.dstdata.update({"attn_dst": attn_dst})
graph.apply_edges(fn.u_add_v("attn_src", "attn_dst", "attn_node"))
else:
graph.apply_edges(fn.copy_u("attn_src", "attn_node"))
e = graph.edata["attn_node"]
if feat_edge is not None:
attn_edge = self.attn_edge_fc(feat_edge).view(-1, self._n_heads, 1)
graph.edata.update({"attn_edge": attn_edge})
e += graph.edata["attn_edge"]
e = self.leaky_relu(e)
if self.training and self.edge_drop > 0:
perm = torch.randperm(graph.number_of_edges(), device=e.device)
bound = int(graph.number_of_edges() * self.edge_drop)
eids = perm[bound:]
graph.edata["a"] = torch.zeros_like(e)
graph.edata["a"][eids] = self.attn_drop(edge_softmax(graph, e[eids], eids=eids))
else:
graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))
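The edge-drop branch above keeps a random `(1 - edge_drop)` fraction of edge ids and computes the attention softmax only on those; dropped edges get zero attention. The selection itself, with a plain-Python stand-in for `torch.randperm`:

```python
import random

def kept_edge_ids(n_edges, edge_drop, seed=0):
    # shuffle edge ids, drop the first int(n_edges * edge_drop) of them,
    # and attend only over the remainder
    perm = list(range(n_edges))
    random.Random(seed).shuffle(perm)
    bound = int(n_edges * edge_drop)
    return perm[bound:]

kept = kept_edge_ids(100, 0.1)
print(len(kept))  # 90 edges survive a 10% edge drop
```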
# message passing
graph.update_all(fn.u_mul_e("feat_src_fc", "a", "m"), fn.sum("m", "feat_src_fc"))
rst = graph.dstdata["feat_src_fc"]
if self._use_symmetric_norm:
degs = graph.in_degrees().float().clamp(min=1)
norm = torch.pow(degs, 0.5)
shp = norm.shape + (1,) * (feat_dst.dim())
norm = torch.reshape(norm, shp)
rst = rst * norm
# residual
if self.dst_fc is not None:
rst += feat_dst_fc
else:
rst += self.bias
# activation
if self.activation is not None:
rst = self.activation(rst, inplace=True)
return rst
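With `use_symmetric_norm`, the forward pass above multiplies source features by deg_out^{-1/2} before aggregation and the aggregated result by deg_in^{1/2} afterwards, a GCN-style symmetric normalization grafted onto attention. A toy sketch on a tiny graph, assuming uniform attention weights (equal scores, so softmax gives 1/in-degree per edge):

```python
# Tiny directed graph: (u, v) means an edge u -> v.
edges = [(0, 2), (1, 2), (2, 0)]
x = [1.0, 2.0, 3.0]  # one scalar feature per node

out_deg = [sum(1 for u, _ in edges if u == i) for i in range(3)]
in_deg = [sum(1 for _, v in edges if v == i) for i in range(3)]

# pre-scale each source by deg_out^{-1/2} (degrees clamped to >= 1, as in the layer)
xs = [xi * max(d, 1) ** -0.5 for xi, d in zip(x, out_deg)]
# aggregate with uniform attention weights (1 / in-degree per incoming edge)
agg = [sum(xs[u] / max(in_deg[v], 1) for u, w in edges if w == v) for v in range(3)]
# post-scale each destination by deg_in^{1/2}
rst = [a * max(d, 1) ** 0.5 for a, d in zip(agg, in_deg)]

print(rst)  # node 2 receives (1 + 2) / 2 * sqrt(2)
```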
class GAT(nn.Module):
def __init__(
self,
node_feats,
edge_feats,
n_classes,
n_layers,
n_heads,
n_hidden,
edge_emb,
activation,
dropout,
input_drop,
attn_drop,
edge_drop,
use_attn_dst=True,
allow_zero_in_degree=False,
residual=False,
):
super().__init__()
self.n_layers = n_layers
self.n_heads = n_heads
self.n_hidden = n_hidden
self.n_classes = n_classes
self.convs = nn.ModuleList()
self.norms = nn.ModuleList()
self.node_encoder = nn.Linear(node_feats, n_hidden)
if edge_emb > 0:
self.edge_encoder = nn.ModuleList()
else:
self.edge_encoder = None
for i in range(n_layers):
in_hidden = n_heads * n_hidden if i > 0 else node_feats
out_hidden = n_hidden
if self.edge_encoder is not None:
self.edge_encoder.append(nn.Linear(edge_feats, edge_emb))
self.convs.append(
GATConv(
in_hidden,
edge_emb,
out_hidden,
n_heads=n_heads,
attn_drop=attn_drop,
edge_drop=edge_drop,
use_attn_dst=use_attn_dst,
allow_zero_in_degree=allow_zero_in_degree,
)
)
self.norms.append(nn.BatchNorm1d(n_heads * out_hidden))
self.pred_linear = nn.Linear(n_heads * n_hidden, n_classes)
self.input_drop = nn.Dropout(input_drop)
self.dropout = nn.Dropout(dropout)
self.activation = activation
self.residual = residual
def forward(self, g, inference=False):
if not isinstance(g, list):
subgraphs = [g] * self.n_layers
else:
subgraphs = g
h = subgraphs[0].srcdata["feat"]
h = self.input_drop(h)
h_last = None
for i in range(self.n_layers):
if self.edge_encoder is not None:
efeat = subgraphs[i].edata["feat"]
efeat_emb = self.edge_encoder[i](efeat)
efeat_emb = F.relu(efeat_emb, inplace=True)
else:
efeat_emb = None
h = self.convs[i](subgraphs[i], h, efeat_emb).flatten(1, -1)
if self.residual and h_last is not None:
h += h_last[: h.shape[0], :]
h_last = h
h = self.norms[i](h)
h = self.activation(h, inplace=True)
h = self.dropout(h)
if inference:
torch.cuda.empty_cache()
h = self.pred_linear(h)
return h
class MLP(nn.Module):
def __init__(
self, in_feats, n_classes, n_layers, n_hidden, activation, dropout=0.0, input_drop=0.0, residual=False,
):
super().__init__()
self.n_layers = n_layers
self.n_hidden = n_hidden
self.n_classes = n_classes
self.linears = nn.ModuleList()
self.norms = nn.ModuleList()
for i in range(n_layers):
in_hidden = n_hidden if i > 0 else in_feats
out_hidden = n_hidden if i < n_layers - 1 else n_classes
self.linears.append(nn.Linear(in_hidden, out_hidden))
if i < n_layers - 1:
self.norms.append(nn.BatchNorm1d(out_hidden))
self.activation = activation
self.input_drop = nn.Dropout(input_drop)
self.dropout = nn.Dropout(dropout)
self.residual = residual
def forward(self, h):
h = self.input_drop(h)
h_last = None
for i in range(self.n_layers):
h = self.linears[i](h)
if self.residual and 0 < i < self.n_layers - 1:
h += h_last
h_last = h
if i < self.n_layers - 1:
h = self.norms[i](h)
h = self.activation(h, inplace=True)
h = self.dropout(h)
return h
import torch
class DataLoaderWrapper(object):
def __init__(self, dataloader):
self.iter = iter(dataloader)
def __iter__(self):
return self
def __next__(self):
try:
return next(self.iter)
except Exception:
raise StopIteration() from None
class BatchSampler(object):
def __init__(self, n, batch_size, shuffle=False):
self.n = n
self.batch_size = batch_size
self.shuffle = shuffle
def __iter__(self):
if not self.shuffle:
perm = torch.arange(start=0, end=self.n)
while True:
if self.shuffle:
perm = torch.randperm(self.n)
shuf = perm.split(self.batch_size)
for shuf_batch in shuf:
yield shuf_batch
yield None
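`BatchSampler` is deliberately infinite: it yields a `None` sentinel after each full pass, and `DataLoaderWrapper` turns the resulting failure in the underlying loader into `StopIteration`, so each epoch wraps the same loader anew. The termination mechanics in miniature (simplified: here the sentinel is checked directly, whereas the real wrapper catches the exception the sentinel provokes inside `NodeDataLoader`):

```python
def infinite_batches(n, batch_size):
    # mirrors BatchSampler: loops forever, yielding None after each full pass
    while True:
        idx = list(range(n))
        for i in range(0, n, batch_size):
            yield idx[i : i + batch_size]
        yield None

class OneEpoch:
    # simplified DataLoaderWrapper: stop as soon as the sentinel appears
    def __init__(self, it):
        self.it = it
    def __iter__(self):
        return self
    def __next__(self):
        batch = next(self.it)
        if batch is None:
            raise StopIteration
        return batch

print([len(b) for b in OneEpoch(infinite_batches(10, 4))])  # [4, 4, 2]
```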
# DGL examples for ogbn-products
Requires DGL 0.5 or later versions.
To get the score of `MLP`, run the following command; you should see the result directly.
```bash
python3 mlp.py --eval-last
```
## Results
Here are the results over 10 runs.
| Method | Validation Accuracy | Test Accuracy | #Parameters |
|:------:|:-------------------:|:---------------:|:-----------:|
| MLP | 0.7841 ± 0.0014 | 0.6320 ± 0.0013 | 535,727 |
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import argparse
import math
import random
import time
from collections import OrderedDict
import dgl.function as fn
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from dgl.dataloading import MultiLayerFullNeighborSampler, MultiLayerNeighborSampler
from dgl.dataloading.pytorch import NodeDataLoader
from matplotlib.ticker import AutoMinorLocator, MultipleLocator
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from torch import nn
from tqdm import tqdm
from models import MLP
from utils import BatchSampler, DataLoaderWrapper
epsilon = 1 - math.log(2)
device = None
dataset = "ogbn-products"
n_node_feats, n_edge_feats, n_classes = 0, 0, 0
def seed(seed=0):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def load_data(dataset):
data = DglNodePropPredDataset(name=dataset)
evaluator = Evaluator(name=dataset)
splitted_idx = data.get_idx_split()
train_idx, val_idx, test_idx = splitted_idx["train"], splitted_idx["valid"], splitted_idx["test"]
graph, labels = data[0]
graph.ndata["labels"] = labels
return graph, labels, train_idx, val_idx, test_idx, evaluator
def preprocess(graph, labels):
global n_node_feats, n_classes
n_node_feats = graph.ndata["feat"].shape[-1]
n_classes = (labels.max() + 1).item()
# graph = graph.remove_self_loop().add_self_loop()
return graph, labels
def gen_model(args):
model = MLP(
n_node_feats,
n_classes,
n_layers=args.n_layers,
n_hidden=args.n_hidden,
activation=F.relu,
dropout=args.dropout,
input_drop=args.input_drop,
residual=False,
)
return model
def custom_loss_function(x, labels):
y = F.cross_entropy(x, labels[:, 0], reduction="none")
y = torch.log(epsilon + y) - math.log(epsilon)
return torch.mean(y)
def train(args, model, dataloader, labels, train_idx, criterion, optimizer, evaluator):
model.train()
loss_sum, total = 0, 0
preds = torch.zeros(labels.shape[0], n_classes)
for _input_nodes, output_nodes, subgraphs in dataloader:
subgraphs = [b.to(device) for b in subgraphs]
new_train_idx = list(range(len(output_nodes)))
pred = model(subgraphs[0].srcdata["feat"])
preds[output_nodes] = pred.cpu().detach()
loss = criterion(pred[new_train_idx], labels[output_nodes][new_train_idx])
optimizer.zero_grad()
loss.backward()
optimizer.step()
count = len(new_train_idx)
loss_sum += loss.item() * count
total += count
return (
loss_sum / total,
evaluator(preds[train_idx], labels[train_idx]),
)
@torch.no_grad()
def evaluate(args, model, dataloader, labels, train_idx, val_idx, test_idx, criterion, evaluator):
model.eval()
preds = torch.zeros(labels.shape[0], n_classes, device=device)
eval_times = 1 # Due to the limitation of memory capacity, we calculate the average of logits 'eval_times' times.
for _ in range(eval_times):
for _input_nodes, output_nodes, subgraphs in dataloader:
subgraphs = [b.to(device) for b in subgraphs]
pred = model(subgraphs[0].srcdata["feat"])
preds[output_nodes] = pred
preds /= eval_times
train_loss = criterion(preds[train_idx], labels[train_idx]).item()
val_loss = criterion(preds[val_idx], labels[val_idx]).item()
test_loss = criterion(preds[test_idx], labels[test_idx]).item()
return (
evaluator(preds[train_idx], labels[train_idx]),
evaluator(preds[val_idx], labels[val_idx]),
evaluator(preds[test_idx], labels[test_idx]),
train_loss,
val_loss,
test_loss,
)
def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running):
evaluator_wrapper = lambda pred, labels: evaluator.eval(
{"y_pred": pred.argmax(dim=-1, keepdim=True), "y_true": labels}
)["acc"]
criterion = custom_loss_function
train_batch_size = 4096
train_sampler = MultiLayerNeighborSampler([0 for _ in range(args.n_layers)])  # do not sample neighbors
train_dataloader = DataLoaderWrapper(
NodeDataLoader(
graph.cpu(),
train_idx.cpu(),
train_sampler,
batch_sampler=BatchSampler(len(train_idx), batch_size=train_batch_size, shuffle=True),
num_workers=4,
)
)
eval_batch_size = 4096
eval_sampler = MultiLayerNeighborSampler([0 for _ in range(args.n_layers)])  # do not sample neighbors
if args.eval_last:
eval_idx = torch.cat([train_idx.cpu(), val_idx.cpu()])
else:
eval_idx = torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx.cpu()])
eval_dataloader = DataLoaderWrapper(
NodeDataLoader(
graph.cpu(),
eval_idx,
eval_sampler,
batch_sampler=BatchSampler(len(eval_idx), batch_size=eval_batch_size, shuffle=False),
num_workers=4,
)
)
model = gen_model(args).to(device)
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="max", factor=0.7, patience=20, verbose=True, min_lr=1e-4
)
best_model_state_dict = None
total_time = 0
val_score, best_val_score, final_test_score = 0, 0, 0
scores, train_scores, val_scores, test_scores = [], [], [], []
losses, train_losses, val_losses, test_losses = [], [], [], []
for epoch in range(1, args.n_epochs + 1):
tic = time.time()
loss, score = train(args, model, train_dataloader, labels, train_idx, criterion, optimizer, evaluator_wrapper)
toc = time.time()
total_time += toc - tic
if epoch % args.eval_every == 0 or epoch % args.log_every == 0:
train_score, val_score, test_score, train_loss, val_loss, test_loss = evaluate(
args, model, eval_dataloader, labels, train_idx, val_idx, test_idx, criterion, evaluator_wrapper
)
if val_score > best_val_score:
best_val_score = val_score
final_test_score = test_score
if args.eval_last:
best_model_state_dict = {k: v.to("cpu") for k, v in model.state_dict().items()}
best_model_state_dict = OrderedDict(best_model_state_dict)
if epoch % args.log_every == 0:
print(
f"Run: {n_running}/{args.n_runs}, Epoch: {epoch}/{args.n_epochs}, Average epoch time: {total_time / epoch}"
)
print(
f"Loss: {loss:.4f}, Score: {score:.4f}\n"
f"Train/Val/Test loss: {train_loss:.4f}/{val_loss:.4f}/{test_loss:.4f}\n"
f"Train/Val/Test/Best val/Final test score: {train_score:.4f}/{val_score:.4f}/{test_score:.4f}/{best_val_score:.4f}/{final_test_score:.4f}"
)
for l, e in zip(
[scores, train_scores, val_scores, test_scores, losses, train_losses, val_losses, test_losses],
[score, train_score, val_score, test_score, loss, train_loss, val_loss, test_loss],
):
l.append(e)
lr_scheduler.step(val_score)
if args.eval_last:
model.load_state_dict(best_model_state_dict)
eval_dataloader = DataLoaderWrapper(
NodeDataLoader(
graph.cpu(),
test_idx.cpu(),
eval_sampler,
batch_sampler=BatchSampler(len(test_idx), batch_size=eval_batch_size, shuffle=False),
num_workers=4,
)
)
final_test_score = evaluate(
args, model, eval_dataloader, labels, train_idx, val_idx, test_idx, criterion, evaluator_wrapper
)[2]
print("*" * 50)
print(f"Average epoch time: {total_time / args.n_epochs}, Test score: {final_test_score}")
if args.plot_curves:
fig = plt.figure(figsize=(24, 24))
ax = fig.gca()
ax.set_xticks(np.arange(0, args.n_epochs, 100))
ax.set_yticks(np.linspace(0, 1.0, 101))
ax.tick_params(labeltop=True, labelright=True)
for y, label in zip([train_scores, val_scores, test_scores], ["train score", "val score", "test score"]):
plt.plot(range(1, args.n_epochs + 1, args.log_every), y, label=label, linewidth=1)
ax.xaxis.set_major_locator(MultipleLocator(20))
ax.xaxis.set_minor_locator(AutoMinorLocator(1))
ax.yaxis.set_major_locator(MultipleLocator(0.01))
ax.yaxis.set_minor_locator(AutoMinorLocator(2))
plt.grid(which="major", color="red", linestyle="dotted")
plt.grid(which="minor", color="orange", linestyle="dotted")
plt.legend()
plt.tight_layout()
plt.savefig(f"gat_score_{n_running}.png")
fig = plt.figure(figsize=(24, 24))
ax = fig.gca()
ax.set_xticks(np.arange(0, args.n_epochs, 100))
ax.tick_params(labeltop=True, labelright=True)
for y, label in zip(
[losses, train_losses, val_losses, test_losses], ["loss", "train loss", "val loss", "test loss"]
):
plt.plot(range(1, args.n_epochs + 1, args.log_every), y, label=label, linewidth=1)
ax.xaxis.set_major_locator(MultipleLocator(20))
ax.xaxis.set_minor_locator(AutoMinorLocator(1))
ax.yaxis.set_major_locator(MultipleLocator(0.1))
ax.yaxis.set_minor_locator(AutoMinorLocator(5))
plt.grid(which="major", color="red", linestyle="dotted")
plt.grid(which="minor", color="orange", linestyle="dotted")
plt.legend()
plt.tight_layout()
plt.savefig(f"gat_loss_{n_running}.png")
return best_val_score, final_test_score
def count_parameters(args):
model = gen_model(args)
return sum([np.prod(p.size()) for p in model.parameters() if p.requires_grad])
def main():
global device
argparser = argparse.ArgumentParser("GAT on OGBN-Proteins", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
argparser.add_argument("--cpu", action="store_true", help="CPU mode. This option overrides '--gpu'.")
argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID.")
argparser.add_argument("--seed", type=int, help="seed", default=0)
argparser.add_argument("--n-runs", type=int, default=10)
argparser.add_argument("--n-epochs", type=int, default=500)
argparser.add_argument("--lr", type=float, default=0.01)
argparser.add_argument("--n-layers", type=int, default=4)
argparser.add_argument("--n-hidden", type=int, default=480)
argparser.add_argument("--dropout", type=float, default=0.2)
argparser.add_argument("--input-drop", type=float, default=0)
argparser.add_argument("--wd", type=float, default=0)
argparser.add_argument("--estimation-mode", action="store_true", help="Estimate the score of test set for speed.")
argparser.add_argument("--eval-last", action="store_true", help="Evaluate the score of test set at last.")
argparser.add_argument("--eval-every", type=int, default=1)
argparser.add_argument("--log-every", type=int, default=1)
argparser.add_argument("--plot-curves", action="store_true")
args = argparser.parse_args()
if args.cpu:
device = torch.device("cpu")
else:
device = torch.device("cuda:%d" % args.gpu)
if args.estimation_mode:
print("WARNING: Estimation mode is enabled. The test score is not accurate.")
seed(args.seed)
graph, labels, train_idx, val_idx, test_idx, evaluator = load_data(dataset)
graph, labels = preprocess(graph, labels)
graph.create_formats_()
# graph = graph.to(device)
labels = labels.to(device)
train_idx = train_idx.to(device)
val_idx = val_idx.to(device)
test_idx = test_idx.to(device)
if args.estimation_mode:
test_idx = test_idx[torch.arange(start=0, end=len(test_idx), step=50)]
val_scores, test_scores = [], []
for i in range(1, args.n_runs + 1):
val_score, test_score = run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, i)
val_scores.append(val_score)
test_scores.append(test_score)
print(args)
print(f"Ran {args.n_runs} times")
print("Val scores:", val_scores)
print("Test scores:", test_scores)
print(f"Average val score: {np.mean(val_scores)} ± {np.std(val_scores)}")
print(f"Average test score: {np.mean(test_scores)} ± {np.std(test_scores)}")
print(f"Number of params: {count_parameters(args)}")
if args.estimation_mode:
print("WARNING: Estimation mode is enabled. The test score is not accurate.")
if __name__ == "__main__":
main()
# Namespace(cpu=False, dropout=0.2, estimation_mode=False, eval_every=1, eval_last=True, gpu=2, input_drop=0, log_every=1, lr=0.01, n_epochs=500, n_hidden=480, n_layers=4, n_runs=10, plot_curves=True, seed=0, wd=0)
# Ran 10 times
# Val scores: [0.7846298603870508, 0.7811713246700405, 0.7828751621188618, 0.7839941001449533, 0.7843501258805279, 0.7841466826030568, 0.7846298603870508, 0.7865880019327112, 0.7832057574447524, 0.7851384685807289]
# Test scores: [0.6318660190656417, 0.6304137516261193, 0.6329961126767946, 0.6312885462007662, 0.6340624944929965, 0.6301507710256831, 0.6314534738969161, 0.6334637843631373, 0.6312465235275007, 0.6329857199726536]
# Average val score: 0.7840729344149735 ± 0.0013702460721628086
# Average test score: 0.6319927196848208 ± 0.001252448369121226
# Number of params: 535727
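The "Average ... score: mean ± std" summary lines above come from `np.mean`/`np.std` over the per-run scores. Note that `np.std` defaults to the population standard deviation (`ddof=0`), which `statistics.pstdev` matches; a minimal stdlib sketch with illustrative numbers (not the real runs):

```python
# Reproducing the "mean ± std" summary with the standard library; np.std's
# default ddof=0 corresponds to statistics.pstdev, not statistics.stdev.
from statistics import mean, pstdev

val_scores = [0.7846, 0.7812, 0.7829, 0.7841]  # illustrative scores, not the real runs
summary = f"Average val score: {mean(val_scores):.4f} ± {pstdev(val_scores):.4f}"
print(summary)
```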
import torch
import torch.nn as nn
import torch.nn.functional as F
class MLP(nn.Module):
def __init__(
self, in_feats, n_classes, n_layers, n_hidden, activation, dropout=0.0, input_drop=0.0, residual=False,
):
super().__init__()
self.n_layers = n_layers
self.n_hidden = n_hidden
self.n_classes = n_classes
self.linears = nn.ModuleList()
self.norms = nn.ModuleList()
for i in range(n_layers):
in_hidden = n_hidden if i > 0 else in_feats
out_hidden = n_hidden if i < n_layers - 1 else n_classes
self.linears.append(nn.Linear(in_hidden, out_hidden))
if i < n_layers - 1:
self.norms.append(nn.BatchNorm1d(out_hidden))
self.activation = activation
self.input_drop = nn.Dropout(input_drop)
self.dropout = nn.Dropout(dropout)
self.residual = residual
def forward(self, h):
h = self.input_drop(h)
h_last = None
for i in range(self.n_layers):
h = self.linears[i](h)
if self.residual and 0 < i < self.n_layers - 1:
h += h_last
h_last = h
if i < self.n_layers - 1:
h = self.norms[i](h)
h = self.activation(h, inplace=True)
h = self.dropout(h)
return h
import torch
class DataLoaderWrapper(object):
def __init__(self, dataloader):
self.iter = iter(dataloader)
def __iter__(self):
return self
def __next__(self):
try:
return next(self.iter)
except Exception:
raise StopIteration() from None
class BatchSampler(object):
def __init__(self, n, batch_size, shuffle=False):
self.n = n
self.batch_size = batch_size
self.shuffle = shuffle
def __iter__(self):
while True:
if self.shuffle:
perm = torch.randperm(self.n)
else:
perm = torch.arange(start=0, end=self.n)
shuf = perm.split(self.batch_size)
for shuf_batch in shuf:
yield shuf_batch
yield None
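The two classes above implement a persistent-worker trick: `BatchSampler` yields index batches forever with a `None` sentinel after each epoch, and `DataLoaderWrapper` keeps a single iterator alive across epochs, converting the failure triggered by the sentinel into `StopIteration` so each `for` loop sees exactly one epoch without respawning dataloader workers. A toy stdlib mimic of the mechanism (`ToyBatchSampler`/`ToyLoader` are hypothetical stand-ins, and the `None` check below simplifies what the real wrapper catches as an exception from `NodeDataLoader`):

```python
# Toy re-creation of the epoch trick: an infinite sampler with a None sentinel,
# plus a wrapper that reuses one iterator for every epoch.
class ToyBatchSampler:
    def __init__(self, n, batch_size):
        self.n, self.batch_size = n, batch_size

    def __iter__(self):
        while True:  # never exhausts on its own
            idx = list(range(self.n))
            for i in range(0, self.n, self.batch_size):
                yield idx[i : i + self.batch_size]
            yield None  # epoch-boundary sentinel


class ToyLoader:
    def __init__(self, sampler):
        self.it = iter(sampler)  # single iterator shared by all epochs

    def __iter__(self):
        return self

    def __next__(self):
        batch = next(self.it)
        if batch is None:  # in the real wrapper this surfaces as an exception
            raise StopIteration
        return batch


loader = ToyLoader(ToyBatchSampler(10, 4))
epoch1 = list(loader)  # one epoch of batches
epoch2 = list(loader)  # the same iterator serves the next epoch
```

Ending an epoch this way leaves the underlying iterator (and hence any worker processes) intact, which is the point of the pattern.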
# DGL for ogbn-proteins

## GAT

DGL implementation of GAT for [ogbn-proteins](https://ogb.stanford.edu/docs/nodeprop/), using some of the techniques from *Bag of Tricks for Node Classification with Graph Neural Networks* ([https://arxiv.org/abs/2103.13355](https://arxiv.org/abs/2103.13355)).

Requires DGL 0.5 or later.
### Usage
For the best score, run `gat.py` and you should directly see the result.
```bash
python3 gat.py
```
For the score of `GAT+labels`, run `gat.py` with `--use-labels` enabled and you should directly see the result.
```bash
python3 gat.py --use-labels
```
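The `--use-labels` option feeds known training-set labels back in as extra input features (the `add_labels` step in `gat.py`). A hypothetical plain-Python sketch of the encoding, with made-up sizes and labels:

```python
# Sketch of the label-reuse trick: training labels become extra one-hot input
# features, while validation/test rows stay all-zero so their labels never
# leak into the model input.
n_nodes, n_classes = 5, 3
labels = [0, 2, 1, 2, 0]  # illustrative labels, one per node
train_idx = [0, 1, 2]     # only these nodes may expose their labels

onehot = [[0.0] * n_classes for _ in range(n_nodes)]
for i in train_idx:
    onehot[i][labels[i]] = 1.0
# rows 3 and 4 remain zero vectors; the one-hot block is then concatenated
# onto the ordinary node features before the first layer
```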
### Results
Here are the results over 10 runs.
| Method | Validation ROC-AUC | Test ROC-AUC | #Parameters |
|:----------:|:------------------:|:---------------:|:-----------:|
| GAT | 0.9276 ± 0.0007 | 0.8747 ± 0.0016 | 2,475,232 |
| GAT+labels | 0.9280 ± 0.0008 | 0.8765 ± 0.0008 | 2,484,192 |
## MWE-GCN and MWE-DGCN
### Models
[MWE-GCN and MWE-DGCN](https://cims.nyu.edu/~chenzh/files/GCN_with_edge_weights.pdf) are GCN models designed for graphs whose edges contain multi-dimensional edge weights that indicate the strengths of the relations represented by the edges.

### Dependencies

- DGL 0.5.2
- PyTorch 1.4.0
- OGB 1.2.0
- Tensorboard 2.1.1

### Usage

To use MWE-GCN:

```python
...
```
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import argparse
import os
import random
import sys
import time
import dgl
import dgl.function as fn
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from dgl.dataloading import MultiLayerFullNeighborSampler, MultiLayerNeighborSampler
from dgl.dataloading.pytorch import NodeDataLoader
from matplotlib.ticker import AutoMinorLocator, MultipleLocator
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from torch import nn
from models import GAT
from utils import BatchSampler, DataLoaderWrapper
device = None
dataset = "ogbn-proteins"
n_node_feats, n_edge_feats, n_classes = 0, 8, 112
def seed(seed=0):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
dgl.random.seed(seed)
def load_data(dataset):
data = DglNodePropPredDataset(name=dataset)
evaluator = Evaluator(name=dataset)
splitted_idx = data.get_idx_split()
train_idx, val_idx, test_idx = splitted_idx["train"], splitted_idx["valid"], splitted_idx["test"]
graph, labels = data[0]
graph.ndata["labels"] = labels
return graph, labels, train_idx, val_idx, test_idx, evaluator
def preprocess(graph, labels, train_idx):
global n_node_feats
# The sum of the weights of adjacent edges is used as node features.
graph.update_all(fn.copy_e("feat", "feat_copy"), fn.sum("feat_copy", "feat"))
n_node_feats = graph.ndata["feat"].shape[-1]
# Only the labels in the training set are used as features, while others are filled with zeros.
graph.ndata["train_labels_onehot"] = torch.zeros(graph.number_of_nodes(), n_classes)
graph.ndata["train_labels_onehot"][train_idx, labels[train_idx, 0]] = 1
graph.ndata["deg"] = graph.out_degrees().float().clamp(min=1)
graph.create_formats_()
return graph, labels
def gen_model(args):
if args.use_labels:
n_node_feats_ = n_node_feats + n_classes
else:
n_node_feats_ = n_node_feats
model = GAT(
n_node_feats_,
n_edge_feats,
n_classes,
n_layers=args.n_layers,
n_heads=args.n_heads,
n_hidden=args.n_hidden,
edge_emb=16,
activation=F.relu,
dropout=args.dropout,
input_drop=args.input_drop,
attn_drop=args.attn_drop,
edge_drop=args.edge_drop,
use_attn_dst=not args.no_attn_dst,
)
return model
def add_labels(graph, idx):
feat = graph.srcdata["feat"]
train_labels_onehot = torch.zeros([feat.shape[0], n_classes], device=device)
train_labels_onehot[idx] = graph.srcdata["train_labels_onehot"][idx]
graph.srcdata["feat"] = torch.cat([feat, train_labels_onehot], dim=-1)
def train(args, model, dataloader, _labels, _train_idx, criterion, optimizer, _evaluator):
model.train()
loss_sum, total = 0, 0
for input_nodes, output_nodes, subgraphs in dataloader:
subgraphs = [b.to(device) for b in subgraphs]
new_train_idx = torch.arange(len(output_nodes), device=device)
if args.use_labels:
train_labels_idx = torch.arange(len(output_nodes), len(input_nodes), device=device)
train_pred_idx = new_train_idx
add_labels(subgraphs[0], train_labels_idx)
else:
train_pred_idx = new_train_idx
pred = model(subgraphs)
loss = criterion(pred[train_pred_idx], subgraphs[-1].dstdata["labels"][train_pred_idx].float())
optimizer.zero_grad()
loss.backward()
optimizer.step()
count = len(train_pred_idx)
loss_sum += loss.item() * count
total += count
# torch.cuda.empty_cache()
return loss_sum / total
@torch.no_grad()
def evaluate(args, model, dataloader, labels, train_idx, val_idx, test_idx, criterion, evaluator):
model.eval()
preds = torch.zeros(labels.shape).to(device)
# Due to memory constraints, we run sampled inference and average the predictions over 'eval_times' passes.
eval_times = 1
for _ in range(eval_times):
for input_nodes, output_nodes, subgraphs in dataloader:
subgraphs = [b.to(device) for b in subgraphs]
new_train_idx = list(range(len(input_nodes)))
if args.use_labels:
add_labels(subgraphs[0], new_train_idx)
pred = model(subgraphs)
preds[output_nodes] += pred
# torch.cuda.empty_cache()
preds /= eval_times
train_loss = criterion(preds[train_idx], labels[train_idx].float()).item()
val_loss = criterion(preds[val_idx], labels[val_idx].float()).item()
test_loss = criterion(preds[test_idx], labels[test_idx].float()).item()
return (
evaluator(preds[train_idx], labels[train_idx]),
evaluator(preds[val_idx], labels[val_idx]),
evaluator(preds[test_idx], labels[test_idx]),
train_loss,
val_loss,
test_loss,
preds,
)
def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running):
evaluator_wrapper = lambda pred, labels: evaluator.eval({"y_pred": pred, "y_true": labels})["rocauc"]
train_batch_size = (len(train_idx) + 9) // 10
# batch_size = len(train_idx)
train_sampler = MultiLayerNeighborSampler([32 for _ in range(args.n_layers)])
# sampler = MultiLayerFullNeighborSampler(args.n_layers)
train_dataloader = DataLoaderWrapper(
NodeDataLoader(
graph.cpu(),
train_idx.cpu(),
train_sampler,
batch_sampler=BatchSampler(len(train_idx), batch_size=train_batch_size),
num_workers=10,
)
)
eval_sampler = MultiLayerNeighborSampler([100 for _ in range(args.n_layers)])
# sampler = MultiLayerFullNeighborSampler(args.n_layers)
eval_dataloader = DataLoaderWrapper(
NodeDataLoader(
graph.cpu(),
torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx.cpu()]),
eval_sampler,
batch_sampler=BatchSampler(graph.number_of_nodes(), batch_size=65536),
num_workers=10,
)
)
criterion = nn.BCEWithLogitsLoss()
model = gen_model(args).to(device)
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.75, patience=50, verbose=True)
total_time = 0
val_score, best_val_score, final_test_score = 0, 0, 0
train_scores, val_scores, test_scores = [], [], []
losses, train_losses, val_losses, test_losses = [], [], [], []
final_pred = None
for epoch in range(1, args.n_epochs + 1):
tic = time.time()
loss = train(args, model, train_dataloader, labels, train_idx, criterion, optimizer, evaluator_wrapper)
toc = time.time()
total_time += toc - tic
if epoch == args.n_epochs or epoch % args.eval_every == 0 or epoch % args.log_every == 0:
train_score, val_score, test_score, train_loss, val_loss, test_loss, pred = evaluate(
args, model, eval_dataloader, labels, train_idx, val_idx, test_idx, criterion, evaluator_wrapper
)
if val_score > best_val_score:
best_val_score = val_score
final_test_score = test_score
final_pred = pred
if epoch % args.log_every == 0:
print(
f"Run: {n_running}/{args.n_runs}, Epoch: {epoch}/{args.n_epochs}, Average epoch time: {total_time / epoch:.2f}s"
)
print(
f"Loss: {loss:.4f}\n"
f"Train/Val/Test loss: {train_loss:.4f}/{val_loss:.4f}/{test_loss:.4f}\n"
f"Train/Val/Test/Best val/Final test score: {train_score:.4f}/{val_score:.4f}/{test_score:.4f}/{best_val_score:.4f}/{final_test_score:.4f}"
)
for l, e in zip(
[train_scores, val_scores, test_scores, losses, train_losses, val_losses, test_losses],
[train_score, val_score, test_score, loss, train_loss, val_loss, test_loss],
):
l.append(e)
lr_scheduler.step(val_score)
print("*" * 50)
print(f"Best val score: {best_val_score}, Final test score: {final_test_score}")
print("*" * 50)
if args.plot:
fig = plt.figure(figsize=(24, 24))
ax = fig.gca()
ax.set_xticks(np.arange(0, args.n_epochs, 100))
ax.set_yticks(np.linspace(0, 1.0, 101))
ax.tick_params(labeltop=True, labelright=True)
for y, label in zip([train_scores, val_scores, test_scores], ["train score", "val score", "test score"]):
plt.plot(range(1, args.n_epochs + 1, args.log_every), y, label=label, linewidth=1)
ax.xaxis.set_major_locator(MultipleLocator(100))
ax.xaxis.set_minor_locator(AutoMinorLocator(1))
ax.yaxis.set_major_locator(MultipleLocator(0.01))
ax.yaxis.set_minor_locator(AutoMinorLocator(2))
plt.grid(which="major", color="red", linestyle="dotted")
plt.grid(which="minor", color="orange", linestyle="dotted")
plt.legend()
plt.tight_layout()
plt.savefig(f"gat_score_{n_running}.png")
fig = plt.figure(figsize=(24, 24))
ax = fig.gca()
ax.set_xticks(np.arange(0, args.n_epochs, 100))
ax.tick_params(labeltop=True, labelright=True)
for y, label in zip(
[losses, train_losses, val_losses, test_losses], ["loss", "train loss", "val loss", "test loss"]
):
plt.plot(range(1, args.n_epochs + 1, args.log_every), y, label=label, linewidth=1)
ax.xaxis.set_major_locator(MultipleLocator(100))
ax.xaxis.set_minor_locator(AutoMinorLocator(1))
ax.yaxis.set_major_locator(MultipleLocator(0.1))
ax.yaxis.set_minor_locator(AutoMinorLocator(5))
plt.grid(which="major", color="red", linestyle="dotted")
plt.grid(which="minor", color="orange", linestyle="dotted")
plt.legend()
plt.tight_layout()
plt.savefig(f"gat_loss_{n_running}.png")
if args.save_pred:
os.makedirs("./output", exist_ok=True)
torch.save(F.softmax(final_pred, dim=1), f"./output/{n_running}.pt")
return best_val_score, final_test_score
def count_parameters(args):
model = gen_model(args)
return sum([np.prod(p.size()) for p in model.parameters() if p.requires_grad])
def main():
global device
argparser = argparse.ArgumentParser(
"GAT implementation on ogbn-proteins", formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
argparser.add_argument("--cpu", action="store_true", help="CPU mode. This option overrides '--gpu'.")
argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID")
argparser.add_argument("--seed", type=int, default=0, help="random seed")
argparser.add_argument("--n-runs", type=int, default=10, help="running times")
argparser.add_argument("--n-epochs", type=int, default=1200, help="number of epochs")
argparser.add_argument(
"--use-labels", action="store_true", help="Use labels in the training set as input features."
)
argparser.add_argument("--no-attn-dst", action="store_true", help="Don't use attn_dst.")
argparser.add_argument("--n-heads", type=int, default=6, help="number of heads")
argparser.add_argument("--lr", type=float, default=0.01, help="learning rate")
argparser.add_argument("--n-layers", type=int, default=6, help="number of layers")
argparser.add_argument("--n-hidden", type=int, default=80, help="number of hidden units")
argparser.add_argument("--dropout", type=float, default=0.25, help="dropout rate")
argparser.add_argument("--input-drop", type=float, default=0.1, help="input drop rate")
argparser.add_argument("--attn-drop", type=float, default=0.0, help="attention dropout rate")
argparser.add_argument("--edge-drop", type=float, default=0.1, help="edge drop rate")
argparser.add_argument("--wd", type=float, default=0, help="weight decay")
argparser.add_argument("--eval-every", type=int, default=5, help="evaluate every EVAL_EVERY epochs")
argparser.add_argument("--log-every", type=int, default=5, help="log every LOG_EVERY epochs")
argparser.add_argument("--plot", action="store_true", help="plot learning curves")
argparser.add_argument("--save-pred", action="store_true", help="save final predictions")
args = argparser.parse_args()
if args.cpu:
device = torch.device("cpu")
else:
device = torch.device(f"cuda:{args.gpu}")
# load data & preprocess
print("Loading data")
graph, labels, train_idx, val_idx, test_idx, evaluator = load_data(dataset)
print("Preprocessing")
graph, labels = preprocess(graph, labels, train_idx)
labels, train_idx, val_idx, test_idx = map(lambda x: x.to(device), (labels, train_idx, val_idx, test_idx))
# run
val_scores, test_scores = [], []
for i in range(args.n_runs):
print("Running", i)
seed(args.seed + i)
val_score, test_score = run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, i + 1)
val_scores.append(val_score)
test_scores.append(test_score)
print(" ".join(sys.argv))
print(args)
print(f"Ran {args.n_runs} times")
print("Val scores:", val_scores)
print("Test scores:", test_scores)
print(f"Average val score: {np.mean(val_scores)} ± {np.std(val_scores)}")
print(f"Average test score: {np.mean(test_scores)} ± {np.std(test_scores)}")
print(f"Number of params: {count_parameters(args)}")
if __name__ == "__main__":
main()
# Namespace(attn_drop=0.0, cpu=False, dropout=0.25, edge_drop=0.1, eval_every=5, gpu=6, input_drop=0.1, log_every=5, lr=0.01, n_epochs=1200, n_heads=6, n_hidden=80, n_layers=6, n_runs=10, no_attn_dst=False, plot=True, save_pred=False, seed=0, use_labels=False, wd=0)
# Ran 10 times
# Val scores: [0.927741031859485, 0.9272113161947824, 0.9271363901359605, 0.9275579074100136, 0.9264291968462317, 0.9275278541203443, 0.9286381790529751, 0.9288245051991526, 0.9269289529175155, 0.9278177920224489]
# Test scores: [0.8754403567694566, 0.8749781870941457, 0.8735933245353141, 0.8759835445000637, 0.8745950242855286, 0.8742530369108132, 0.8784892022402326, 0.873345314887444, 0.8724393129004984, 0.874077975765639]
# Average val score: 0.927581312575891 ± 0.0006953509986591492
# Average test score: 0.8747195279889135 ± 0.001593598488797452
# Number of params: 2475232
# Namespace(attn_drop=0.0, cpu=False, dropout=0.25, edge_drop=0.1, eval_every=5, gpu=7, input_drop=0.1, log_every=5, lr=0.01, n_epochs=1200, n_heads=6, n_hidden=80, n_layers=6, n_runs=10, no_attn_dst=False, plot=True, save_pred=False, seed=0, use_labels=True, wd=0)
# Ran 10 times
# Val scores: [0.9293776332568928, 0.9281066322254939, 0.9286775378440911, 0.9270252685136046, 0.9267937838323375, 0.9277731792338011, 0.9285615428437761, 0.9270819730221879, 0.9276822010553241, 0.9287115722177839]
# Test scores: [0.8761623033485811, 0.8773002619440896, 0.8756680817047869, 0.8751873860287073, 0.875781797307807, 0.8764533839446703, 0.8771202308989311, 0.8765888651476396, 0.8773581283481205, 0.8777751912293709]
# Average val score: 0.9279791324045293 ± 0.0008115348697502517
# Average test score: 0.8765395629902706 ± 0.0008016806017700173
# Number of params: 2484192
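The scores above are ROC-AUC values, which the OGB evaluator averages over the 112 binary protein-function tasks of ogbn-proteins. For reference, a single task's AUC can be computed with the rank (Mann-Whitney) identity; the helper below is an illustrative stand-in, not the OGB `Evaluator`:

```python
# Illustrative single-task ROC-AUC: the probability that a random positive
# example scores higher than a random negative one (ties count as half).
def binary_auc(scores, labels):
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

print(binary_auc([0.9, 0.8, 0.3, 0.1], [1, 0, 1, 0]))  # 0.75
```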
import os
import time

import dgl.function as fn
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ogb.nodeproppred import Evaluator
from ogb.nodeproppred.dataset_dgl import DglNodePropPredDataset
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter

from utils import load_model, set_random_seed
def normalize_edge_weights(graph, device, num_ew_channels):
    degs = graph.in_degrees().float()
    degs = torch.clamp(degs, min=1)
    norm = torch.pow(degs, 0.5)
    norm = norm.to(args["device"])
    graph.ndata["norm"] = norm.unsqueeze(1)
    graph.apply_edges(fn.e_div_u("feat", "norm", "feat"))
    graph.apply_edges(fn.e_div_v("feat", "norm", "feat"))
    for channel in range(num_ew_channels):
        graph.edata["feat_" + str(channel)] = graph.edata["feat"][:, channel : channel + 1]


def run_a_train_epoch(graph, node_idx, model, criterion, optimizer, evaluator):
    model.train()
    logits = model(graph)[node_idx]
    labels = graph.ndata["labels"][node_idx]
    loss = criterion(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    ...
    labels = labels.cpu().numpy()
    preds = logits.cpu().detach().numpy()
    return loss, evaluator.eval({"y_true": labels, "y_pred": preds})["rocauc"]
def run_an_eval_epoch(graph, splitted_idx, model, evaluator):
    model.eval()
    with torch.no_grad():
        logits = model(graph)
        labels = graph.ndata["labels"].cpu().numpy()
        preds = logits.cpu().detach().numpy()
        train_score = evaluator.eval({"y_true": labels[splitted_idx["train"]], "y_pred": preds[splitted_idx["train"]]})
        val_score = evaluator.eval({"y_true": labels[splitted_idx["valid"]], "y_pred": preds[splitted_idx["valid"]]})
        test_score = evaluator.eval({"y_true": labels[splitted_idx["test"]], "y_pred": preds[splitted_idx["test"]]})

    return train_score["rocauc"], val_score["rocauc"], test_score["rocauc"]
def main(args):
    print(args)

    if args["rand_seed"] > -1:
        set_random_seed(args["rand_seed"])

    dataset = DglNodePropPredDataset(name=args["dataset"])
    print(dataset.meta_info)
    splitted_idx = dataset.get_idx_split()
    graph = dataset.graph[0]
    graph.ndata["labels"] = dataset.labels.float().to(args["device"])
    graph.edata["feat"] = graph.edata["feat"].float().to(args["device"])

    if args["ewnorm"] == "both":
        print("Symmetric normalization of edge weights by degree")
        normalize_edge_weights(graph, args["device"], args["num_ew_channels"])
    elif args["ewnorm"] == "none":
        print("Not normalizing edge weights")
        for channel in range(args["num_ew_channels"]):
            graph.edata["feat_" + str(channel)] = graph.edata["feat"][:, channel : channel + 1]

    model = load_model(args).to(args["device"])
    optimizer = Adam(model.parameters(), lr=args["lr"], weight_decay=args["weight_decay"])
    min_lr = 1e-3
    scheduler = ReduceLROnPlateau(optimizer, "max", factor=0.7, patience=100, verbose=True, min_lr=min_lr)
    print("scheduler min_lr", min_lr)
    criterion = nn.BCEWithLogitsLoss()
    evaluator = Evaluator(args["dataset"])

    print("model", args["model"])
    print("n_layers", args["n_layers"])
    print("hidden dim", args["hidden_feats"])
    print("lr", args["lr"])

    dur = []
    best_val_score = 0.0
    num_patient_epochs = 0
    model_folder = "./saved_models/"
    model_path = model_folder + str(args["exp_name"]) + "_" + str(args["postfix"])
    if not os.path.exists(model_folder):
        os.makedirs(model_folder)
    for epoch in range(1, args["num_epochs"] + 1):
        if epoch >= 3:
            t0 = time.time()
        loss, train_score = run_a_train_epoch(graph, splitted_idx["train"], model, criterion, optimizer, evaluator)
        if epoch >= 3:
            dur.append(time.time() - t0)
        ...
        else:
            avg_time = None

        train_score, val_score, test_score = run_an_eval_epoch(graph, splitted_idx, model, evaluator)
        scheduler.step(val_score)
@@ -132,50 +124,54 @@ def main(args):
        else:
            num_patient_epochs += 1
        print(
            "Epoch {:d}, loss {:.4f}, train score {:.4f}, "
            "val score {:.4f}, avg time {}, num patient epochs {:d}".format(
                epoch, loss, train_score, val_score, avg_time, num_patient_epochs
            )
        )
        if num_patient_epochs == args["patience"]:
            break

    model.load_state_dict(torch.load(model_path))
    train_score, val_score, test_score = run_an_eval_epoch(graph, splitted_idx, model, evaluator)
    print("Train score {:.4f}".format(train_score))
    print("Valid score {:.4f}".format(val_score))
    print("Test score {:.4f}".format(test_score))
    with open("results.txt", "w") as f:
        f.write("loss {:.4f}\n".format(loss))
        f.write("Best validation rocauc {:.4f}\n".format(best_val_score))
        f.write("Test rocauc {:.4f}\n".format(test_score))

    print(args)
if __name__ == "__main__":
    import argparse

    from configure import get_exp_configure

    parser = argparse.ArgumentParser(description="OGB node property prediction with DGL using full graph training")
    parser.add_argument(
        "-m", "--model", type=str, choices=["MWE-GCN", "MWE-DGCN"], default="MWE-DGCN", help="Model to use"
    )
    parser.add_argument("-c", "--cuda", type=str, default="none")
    parser.add_argument("--postfix", type=str, default="", help="a string appended to the file name of the saved model")
    parser.add_argument("--rand_seed", type=int, default=-1, help="random seed for torch and numpy")
    parser.add_argument("--residual", action="store_true")
    parser.add_argument("--ewnorm", type=str, default="none", choices=["none", "both"])
    args = parser.parse_args().__dict__

    # Get experiment configuration
    args["dataset"] = "ogbn-proteins"
    args["exp_name"] = "_".join([args["model"], args["dataset"]])
    args.update(get_exp_configure(args))

    if not (args["cuda"] == "none"):
        # note: no space after the colon; "cuda: 0" is not a valid device string
        args["device"] = torch.device("cuda:" + str(args["cuda"]))
    else:
        args["device"] = torch.device("cpu")

    main(args)
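The training loop in `main` above pairs `ReduceLROnPlateau` with its own patience counter: it checkpoints on every validation improvement, stops after `args["patience"]` flat epochs, and reloads the best checkpoint before the final evaluation. A dependency-free sketch of that control flow (the function name and the score sequence are invented for illustration):

```python
def early_stopping_loop(val_scores, patience):
    """Track the best validation score; stop after `patience`
    epochs without improvement and report the best epoch."""
    best_score, best_epoch, num_patient_epochs = 0.0, None, 0
    for epoch, score in enumerate(val_scores, start=1):
        if score > best_score:
            # analogous to torch.save(model.state_dict(), model_path)
            best_score, best_epoch = score, epoch
            num_patient_epochs = 0
        else:
            num_patient_epochs += 1
        if num_patient_epochs == patience:
            break
    # analogous to model.load_state_dict(torch.load(model_path))
    return best_score, best_epoch

print(early_stopping_loop([0.61, 0.67, 0.66, 0.65, 0.64], patience=2))  # → (0.67, 2)
```

The scheduler's own `patience=100` is deliberately shorter than the early-stopping patience in the config, so the learning rate decays before training is abandoned.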
import math
from functools import partial

import dgl.function as fn
import dgl.nn.pytorch as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl._ffi.base import DGLError
from dgl.base import ALL
from dgl.nn.pytorch.utils import Identity
from dgl.ops import edge_softmax
from dgl.utils import expand_as_pair
from torch.nn import init
from torch.utils.checkpoint import checkpoint
class MWEConv(nn.Module):
    def __init__(self, in_feats, out_feats, activation, bias=True, num_channels=8, aggr_mode="sum"):
        super(MWEConv, self).__init__()
        self.num_channels = num_channels
        self._in_feats = in_feats
@@ -26,18 +31,18 @@ class MWEConv(nn.Module):
        self.reset_parameters()
        self.activation = activation

        if aggr_mode == "concat":
            self.aggr_mode = "concat"
            self.final = nn.Linear(out_feats * self.num_channels, out_feats)
        elif aggr_mode == "sum":
            self.aggr_mode = "sum"
            self.final = nn.Linear(out_feats, out_feats)

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            stdv = 1.0 / math.sqrt(self.bias.size(0))
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, g, node_state_prev):
@@ -54,20 +59,22 @@ class MWEConv(nn.Module):
        for c in range(self.num_channels):
            node_state_c = node_state

            if self._out_feats < self._in_feats:
                g.ndata["feat_" + str(c)] = torch.mm(node_state_c, self.weight[:, :, c])
            else:
                g.ndata["feat_" + str(c)] = node_state_c

            g.update_all(
                fn.src_mul_edge("feat_" + str(c), "feat_" + str(c), "m"), fn.sum("m", "feat_" + str(c) + "_new")
            )
            node_state_c = g.ndata.pop("feat_" + str(c) + "_new")

            if self._out_feats >= self._in_feats:
                node_state_c = torch.mm(node_state_c, self.weight[:, :, c])
            if self.bias is not None:
                node_state_c = node_state_c + self.bias[:, c]
            node_state_c = self.activation(node_state_c)
            new_node_states.append(node_state_c)

        if self.aggr_mode == "sum":
            node_states = torch.stack(new_node_states, dim=1).sum(1)
        elif self.aggr_mode == "concat":
            node_states = torch.cat(new_node_states, dim=1)

        node_states = self.final(node_states)
@@ -76,25 +83,15 @@ class MWEConv(nn.Module):
class MWE_GCN(nn.Module):
    def __init__(self, n_input, n_hidden, n_output, n_layers, activation, dropout, aggr_mode="sum", device="cpu"):
        super(MWE_GCN, self).__init__()
        self.dropout = dropout
        self.activation = activation

        self.layers = nn.ModuleList()
        self.layers.append(MWEConv(n_input, n_hidden, activation=activation, aggr_mode=aggr_mode))
        for i in range(n_layers - 1):
            self.layers.append(MWEConv(n_hidden, n_hidden, activation=activation, aggr_mode=aggr_mode))

        self.pred_out = nn.Linear(n_hidden, n_output)
        self.device = device
@@ -105,23 +102,16 @@ class MWE_GCN(nn.Module):
        for layer in self.layers:
            node_state = F.dropout(node_state, p=self.dropout, training=self.training)
            node_state = layer(g, node_state)
            node_state = self.activation(node_state)

        out = self.pred_out(node_state)
        return out
class MWE_DGCN(nn.Module):
    def __init__(
        self, n_input, n_hidden, n_output, n_layers, activation, dropout, residual=False, aggr_mode="sum", device="cpu"
    ):
        super(MWE_DGCN, self).__init__()
        self.n_layers = n_layers
        self.activation = activation
@@ -131,12 +121,10 @@ class MWE_DGCN(nn.Module):
        self.layers = nn.ModuleList()
        self.layer_norms = nn.ModuleList()

        self.layers.append(MWEConv(n_input, n_hidden, activation=activation, aggr_mode=aggr_mode))
        for i in range(n_layers - 1):
            self.layers.append(MWEConv(n_hidden, n_hidden, activation=activation, aggr_mode=aggr_mode))

        for i in range(n_layers):
            self.layer_norms.append(nn.LayerNorm(n_hidden, elementwise_affine=True))
@@ -144,23 +132,22 @@ class MWE_DGCN(nn.Module):
        self.pred_out = nn.Linear(n_hidden, n_output)
        self.device = device
    def forward(self, g, node_state=None):
        node_state = torch.ones(g.number_of_nodes(), 1).float().to(self.device)

        node_state = self.layers[0](g, node_state)

        for layer in range(1, self.n_layers):
            node_state_new = self.layer_norms[layer - 1](node_state)
            node_state_new = self.activation(node_state_new)
            node_state_new = F.dropout(node_state_new, p=self.dropout, training=self.training)

            # `residual` arrives as a bool from argparse's store_true action
            if self.residual:
                node_state = node_state + self.layers[layer](g, node_state_new)
            else:
                node_state = self.layers[layer](g, node_state_new)

        node_state = self.layer_norms[self.n_layers - 1](node_state)
        node_state = self.activation(node_state)
        node_state = F.dropout(node_state, p=self.dropout, training=self.training)
@@ -169,3 +156,249 @@ class MWE_DGCN(nn.Module):
        return out
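`MWEConv` above computes, per channel, messages of the form source-feature times per-channel edge weight (`fn.src_mul_edge`), sums them at the destination, and then merges the channels by `sum` or `concat`. A torch-free sketch of that aggregation on a toy graph (the function name, edge list, and weights are all made up for illustration):

```python
def mwe_aggregate(edges, feats, edge_w, num_channels, aggr_mode="sum"):
    """edges: list of (src, dst); feats: node -> scalar feature;
    edge_w[eid][c]: channel-c weight of edge eid."""
    channels = []
    for c in range(num_channels):
        out = {v: 0.0 for v in feats}
        for eid, (u, v) in enumerate(edges):
            # message = src feature * channel-c edge weight, summed at dst
            out[v] += feats[u] * edge_w[eid][c]
        channels.append(out)
    if aggr_mode == "sum":
        return {v: sum(ch[v] for ch in channels) for v in feats}
    elif aggr_mode == "concat":
        return {v: [ch[v] for ch in channels] for v in feats}

edges = [(0, 1), (2, 1)]
feats = {0: 1.0, 1: 2.0, 2: 3.0}
edge_w = [[0.5, 1.0], [1.0, 0.0]]
print(mwe_aggregate(edges, feats, edge_w, num_channels=2))  # → {0: 0.0, 1: 4.5, 2: 0.0}
```

With `concat` mode the per-channel sums are kept separate instead, which is why the real layer sizes `self.final` as `out_feats * num_channels` in that branch.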
class GATConv(nn.Module):
    def __init__(
        self,
        node_feats,
        edge_feats,
        out_feats,
        n_heads=1,
        attn_drop=0.0,
        edge_drop=0.0,
        negative_slope=0.2,
        residual=True,
        activation=None,
        use_attn_dst=True,
        allow_zero_in_degree=True,
        use_symmetric_norm=False,
    ):
        super(GATConv, self).__init__()
        self._n_heads = n_heads
        self._in_src_feats, self._in_dst_feats = expand_as_pair(node_feats)
        self._out_feats = out_feats
        self._allow_zero_in_degree = allow_zero_in_degree
        self._use_symmetric_norm = use_symmetric_norm

        # feat fc
        self.src_fc = nn.Linear(self._in_src_feats, out_feats * n_heads, bias=False)
        if residual:
            self.dst_fc = nn.Linear(self._in_src_feats, out_feats * n_heads)
            self.bias = None
        else:
            self.dst_fc = None
            self.bias = nn.Parameter(torch.zeros(out_feats * n_heads))

        # attn fc
        self.attn_src_fc = nn.Linear(self._in_src_feats, n_heads, bias=False)
        if use_attn_dst:
            self.attn_dst_fc = nn.Linear(self._in_src_feats, n_heads, bias=False)
        else:
            self.attn_dst_fc = None
        if edge_feats > 0:
            self.attn_edge_fc = nn.Linear(edge_feats, n_heads, bias=False)
        else:
            self.attn_edge_fc = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.edge_drop = edge_drop
        self.leaky_relu = nn.LeakyReLU(negative_slope, inplace=True)
        self.activation = activation

        self.reset_parameters()

    def reset_parameters(self):
        gain = nn.init.calculate_gain("relu")
        nn.init.xavier_normal_(self.src_fc.weight, gain=gain)
        if self.dst_fc is not None:
            nn.init.xavier_normal_(self.dst_fc.weight, gain=gain)
        nn.init.xavier_normal_(self.attn_src_fc.weight, gain=gain)
        if self.attn_dst_fc is not None:
            nn.init.xavier_normal_(self.attn_dst_fc.weight, gain=gain)
        if self.attn_edge_fc is not None:
            nn.init.xavier_normal_(self.attn_edge_fc.weight, gain=gain)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def set_allow_zero_in_degree(self, set_value):
        self._allow_zero_in_degree = set_value
    def forward(self, graph, feat_src, feat_edge=None):
        with graph.local_scope():
            if not self._allow_zero_in_degree:
                if (graph.in_degrees() == 0).any():
                    raise DGLError(
                        "There are 0-in-degree nodes in the graph; "
                        "their output will be invalid."
                    )

            if graph.is_block:
                feat_dst = feat_src[: graph.number_of_dst_nodes()]
            else:
                feat_dst = feat_src

            if self._use_symmetric_norm:
                degs = graph.srcdata["deg"]
                # degs = graph.out_degrees().float().clamp(min=1)
                norm = torch.pow(degs, -0.5)
                shp = norm.shape + (1,) * (feat_src.dim() - 1)
                norm = torch.reshape(norm, shp)
                feat_src = feat_src * norm

            feat_src_fc = self.src_fc(feat_src).view(-1, self._n_heads, self._out_feats)
            if self.dst_fc is not None:
                feat_dst_fc = self.dst_fc(feat_dst).view(-1, self._n_heads, self._out_feats)
            attn_src = self.attn_src_fc(feat_src).view(-1, self._n_heads, 1)

            # NOTE: the GAT paper uses "first concatenation then linear projection"
            # to compute attention scores, while ours is "first projection then
            # addition"; the two approaches are mathematically equivalent:
            # decompose the weight vector a mentioned in the paper into
            # [a_l || a_r], then
            #     a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j
            # Our implementation is more efficient because we do not need to
            # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,
            # the addition can be optimized with DGL's built-in function u_add_v,
            # which further speeds up computation and saves memory footprint.
            graph.srcdata.update({"feat_src_fc": feat_src_fc, "attn_src": attn_src})

            if self.attn_dst_fc is not None:
                attn_dst = self.attn_dst_fc(feat_dst).view(-1, self._n_heads, 1)
                graph.dstdata.update({"attn_dst": attn_dst})
                graph.apply_edges(fn.u_add_v("attn_src", "attn_dst", "attn_node"))
            else:
                graph.apply_edges(fn.copy_u("attn_src", "attn_node"))

            e = graph.edata["attn_node"]
            if feat_edge is not None:
                attn_edge = self.attn_edge_fc(feat_edge).view(-1, self._n_heads, 1)
                graph.edata.update({"attn_edge": attn_edge})
                e += graph.edata["attn_edge"]
            e = self.leaky_relu(e)

            if self.training and self.edge_drop > 0:
                perm = torch.randperm(graph.number_of_edges(), device=e.device)
                bound = int(graph.number_of_edges() * self.edge_drop)
                eids = perm[bound:]
                graph.edata["a"] = torch.zeros_like(e)
                graph.edata["a"][eids] = self.attn_drop(edge_softmax(graph, e[eids], eids=eids))
            else:
                graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))

            # message passing
            graph.update_all(fn.u_mul_e("feat_src_fc", "a", "m"), fn.sum("m", "feat_src_fc"))
            rst = graph.dstdata["feat_src_fc"]

            if self._use_symmetric_norm:
                degs = graph.dstdata["deg"]
                # degs = graph.in_degrees().float().clamp(min=1)
                norm = torch.pow(degs, 0.5)
                shp = norm.shape + (1,) * (feat_dst.dim())
                norm = torch.reshape(norm, shp)
                rst = rst * norm

            # residual
            if self.dst_fc is not None:
                rst += feat_dst_fc
            else:
                rst += self.bias

            # activation
            if self.activation is not None:
                rst = self.activation(rst, inplace=True)

            return rst
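The NOTE inside `GATConv.forward` claims that projecting source and destination separately and adding (`u_add_v`) is equivalent to the GAT paper's concatenate-then-project attention score. A quick scalar check with arbitrary made-up numbers (pure Python; `+` on lists is concatenation here, not vector addition):

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# decompose the attention vector a into [a_l || a_r]
a_l, a_r = [0.3, -1.2], [0.7, 0.4]
wh_i, wh_j = [1.0, 2.0], [-0.5, 3.0]

# paper form: a^T [Wh_i || Wh_j]  (list `+` concatenates)
concat_form = dot(a_l + a_r, wh_i + wh_j)
# DGL form: a_l Wh_i + a_r Wh_j
split_form = dot(a_l, wh_i) + dot(a_r, wh_j)

print(abs(concat_form - split_form) < 1e-12)  # → True
```

The split form lets the layer store one scalar per node (`attn_src`, `attn_dst`) instead of a concatenated feature per edge, which is where the memory saving comes from.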
class GAT(nn.Module):
    def __init__(
        self,
        node_feats,
        edge_feats,
        n_classes,
        n_layers,
        n_heads,
        n_hidden,
        edge_emb,
        activation,
        dropout,
        input_drop,
        attn_drop,
        edge_drop,
        use_attn_dst=True,
        allow_zero_in_degree=False,
    ):
        super().__init__()
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_hidden = n_hidden
        self.n_classes = n_classes

        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()

        self.node_encoder = nn.Linear(node_feats, n_hidden)
        if edge_emb > 0:
            self.edge_encoder = nn.ModuleList()
        else:
            self.edge_encoder = None

        for i in range(n_layers):
            in_hidden = n_heads * n_hidden if i > 0 else n_hidden
            out_hidden = n_hidden
            # bias = i == n_layers - 1

            if edge_emb > 0:
                self.edge_encoder.append(nn.Linear(edge_feats, edge_emb))
            self.convs.append(
                GATConv(
                    in_hidden,
                    edge_emb,
                    out_hidden,
                    n_heads=n_heads,
                    attn_drop=attn_drop,
                    edge_drop=edge_drop,
                    use_attn_dst=use_attn_dst,
                    allow_zero_in_degree=allow_zero_in_degree,
                    use_symmetric_norm=False,
                )
            )
            self.norms.append(nn.BatchNorm1d(n_heads * out_hidden))

        self.pred_linear = nn.Linear(n_heads * n_hidden, n_classes)

        self.input_drop = nn.Dropout(input_drop)
        self.dropout = nn.Dropout(dropout)
        self.activation = activation
    def forward(self, g):
        if not isinstance(g, list):
            subgraphs = [g] * self.n_layers
        else:
            subgraphs = g

        h = subgraphs[0].srcdata["feat"]

        h = self.node_encoder(h)
        h = F.relu(h, inplace=True)
        h = self.input_drop(h)

        h_last = None

        for i in range(self.n_layers):
            if self.edge_encoder is not None:
                efeat = subgraphs[i].edata["feat"]
                efeat_emb = self.edge_encoder[i](efeat)
                efeat_emb = F.relu(efeat_emb, inplace=True)
            else:
                efeat_emb = None

            h = self.convs[i](subgraphs[i], h, efeat_emb).flatten(1, -1)

            if h_last is not None:
                h += h_last[: h.shape[0], :]

            h_last = h

            h = self.norms[i](h)
            h = self.activation(h, inplace=True)
            h = self.dropout(h)

        h = self.pred_linear(h)

        return h
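`edge_softmax`, which `GATConv` uses to normalize attention scores, is a softmax taken over each destination node's incoming edges rather than over a whole tensor. A minimal pure-Python equivalent for single-head scalar scores (the helper name and the toy edge list are invented for illustration):

```python
import math

def edge_softmax_py(dst, scores):
    """Softmax edge scores grouped by destination node.
    dst[i] is the destination node of edge i."""
    exp = [math.exp(s) for s in scores]
    totals = {}
    for d, e in zip(dst, exp):
        totals[d] = totals.get(d, 0.0) + e
    # each edge's weight is its share of its destination's total
    return [e / totals[d] for d, e in zip(dst, exp)]

# two edges into node 0 with equal scores, one edge into node 1
print(edge_softmax_py([0, 0, 1], [1.0, 1.0, 5.0]))  # → [0.5, 0.5, 1.0]
```

The `edge_drop` branch in `GATConv.forward` runs this normalization over only the surviving subset of edges (`eids`), leaving dropped edges with weight zero.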
import random

import numpy as np
import torch
import torch.nn.functional as F
from models import MWE_DGCN, MWE_GCN
def set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    print("random seed set to be " + str(seed))
def load_model(args):
    if args["model"] == "MWE-GCN":
        model = MWE_GCN(
            n_input=args["in_feats"],
            n_hidden=args["hidden_feats"],
            n_output=args["out_feats"],
            n_layers=args["n_layers"],
            activation=torch.nn.Tanh(),
            dropout=args["dropout"],
            aggr_mode=args["aggr_mode"],
            device=args["device"],
        )
    elif args["model"] == "MWE-DGCN":
        model = MWE_DGCN(
            n_input=args["in_feats"],
            n_hidden=args["hidden_feats"],
            n_output=args["out_feats"],
            n_layers=args["n_layers"],
            activation=torch.nn.ReLU(),
            dropout=args["dropout"],
            aggr_mode=args["aggr_mode"],
            residual=args["residual"],
            device=args["device"],
        )
    else:
        raise ValueError("Unexpected model {}".format(args["model"]))
    return model
class Logger(object):
    def __init__(self, runs, info=None):
        self.info = info
@@ -53,11 +60,11 @@ class Logger(object):
        if run is not None:
            result = 100 * torch.tensor(self.results[run])
            argmax = result[:, 1].argmax().item()
            print(f"Run {run + 1:02d}:")
            print(f"Highest Train: {result[:, 0].max():.2f}")
            print(f"Highest Valid: {result[:, 1].max():.2f}")
            print(f"  Final Train: {result[argmax, 0]:.2f}")
            print(f"   Final Test: {result[argmax, 2]:.2f}")
        else:
            result = 100 * torch.tensor(self.results)
@@ -71,12 +78,39 @@ class Logger(object):
            best_result = torch.tensor(best_results)

            print("All runs:")
            r = best_result[:, 0]
            print(f"Highest Train: {r.mean():.2f} ± {r.std():.2f}")
            r = best_result[:, 1]
            print(f"Highest Valid: {r.mean():.2f} ± {r.std():.2f}")
            r = best_result[:, 2]
            print(f"  Final Train: {r.mean():.2f} ± {r.std():.2f}")
            r = best_result[:, 3]
            print(f"   Final Test: {r.mean():.2f} ± {r.std():.2f}")
class DataLoaderWrapper(object):
    def __init__(self, dataloader):
        self.iter = iter(dataloader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.iter)
        except Exception:
            raise StopIteration() from None


class BatchSampler(object):
    def __init__(self, n, batch_size):
        self.n = n
        self.batch_size = batch_size

    def __iter__(self):
        while True:
            shuf = torch.randperm(self.n).split(self.batch_size)
            for shuf_batch in shuf:
                yield shuf_batch
            yield None
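`BatchSampler` above is an infinite sampler that emits one shuffled epoch of index batches followed by a `None` sentinel, which lets a `DataLoaderWrapper`-style consumer end iteration at the epoch boundary. A torch-free sketch of the same contract, with `random.shuffle` standing in for `torch.randperm` (the function name and sizes are made up):

```python
import random

def batch_sampler(n, batch_size):
    """Yield shuffled index batches forever, with None as an end-of-epoch marker."""
    while True:
        idx = list(range(n))
        random.shuffle(idx)
        for i in range(0, n, batch_size):
            yield idx[i : i + batch_size]
        yield None

sampler = batch_sampler(n=10, batch_size=4)
# iter(callable, sentinel) stops at the None sentinel, i.e. after one epoch
batches = list(iter(lambda: next(sampler), None))
print([len(b) for b in batches])  # → [4, 4, 2]
```

Keeping the sampler infinite and marking epochs in-band avoids rebuilding the dataloader every epoch; the wrapper only has to watch for the sentinel.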