Unverified Commit 0b9df9d7 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4652)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent f19f05ce
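The hunks below are the mechanical output of running the Black formatter over these example scripts: string quotes are normalized to double quotes, trailing commas are added, and long calls are wrapped, with no behavioural change. As a minimal sketch of that kind of pass (illustrative only, not part of this commit), the snippet below formats a source string through Black's Python API; the line_length=79 value is an assumption and may not match the repository's actual configuration.

# Illustrative only: reproduce the kind of rewrite applied in the hunks below.
# Assumes the `black` package is installed; line_length=79 is a guess at the
# project's setting, not something recorded in this commit.
import black

SRC = "MWE_GCN_proteins = {'num_ew_channels': 8, 'aggr_mode': 'sum'}\n"

# format_str() is Black's public API for formatting a source string in memory.
print(black.format_str(SRC, mode=black.Mode(line_length=79)))
# -> MWE_GCN_proteins = {"num_ew_channels": 8, "aggr_mode": "sum"}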
@@ -2,39 +2,39 @@
import torch

MWE_GCN_proteins = {
    "num_ew_channels": 8,
    "num_epochs": 2000,
    "in_feats": 1,
    "hidden_feats": 10,
    "out_feats": 112,
    "n_layers": 3,
    "lr": 2e-2,
    "weight_decay": 0,
    "patience": 1000,
    "dropout": 0.2,
    "aggr_mode": "sum",  ## 'sum' or 'concat' for the aggregation across channels
    "ewnorm": "both",
}

MWE_DGCN_proteins = {
    "num_ew_channels": 8,
    "num_epochs": 2000,
    "in_feats": 1,
    "hidden_feats": 10,
    "out_feats": 112,
    "n_layers": 2,
    "lr": 1e-2,
    "weight_decay": 0,
    "patience": 300,
    "dropout": 0.5,
    "aggr_mode": "sum",
    "residual": True,
    "ewnorm": "none",
}


def get_exp_configure(args):
    if args["model"] == "MWE-GCN":
        return MWE_GCN_proteins
    elif args["model"] == "MWE-DGCN":
        return MWE_DGCN_proteins
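For context, a minimal usage sketch of the lookup defined above (hypothetical values; how the caller merges the result is elided in this diff):

# Hypothetical usage of get_exp_configure(); the real training script builds
# `args` from argparse (see the -m/--model option further below) and reads the
# matching hyperparameter dict. The update() merge here is an assumption.
from configure import get_exp_configure

args = {"model": "MWE-DGCN"}
args.update(get_exp_configure(args))
print(args["lr"], args["n_layers"])  # 0.01 2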
@@ -7,20 +7,23 @@ import random
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from matplotlib.ticker import AutoMinorLocator, MultipleLocator
from models import GAT
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from torch import nn

import dgl
import dgl.function as fn
from dgl.dataloading import (
    DataLoader,
    MultiLayerFullNeighborSampler,
    MultiLayerNeighborSampler,
)

device = None
dataset = "ogbn-proteins"
@@ -43,7 +46,11 @@ def load_data(dataset):
    evaluator = Evaluator(name=dataset)
    splitted_idx = data.get_idx_split()
    train_idx, val_idx, test_idx = (
        splitted_idx["train"],
        splitted_idx["valid"],
        splitted_idx["test"],
    )
    graph, labels = data[0]
    graph.ndata["labels"] = labels
@@ -54,11 +61,15 @@ def preprocess(graph, labels, train_idx):
    global n_node_feats

    # The sum of the weights of adjacent edges is used as node features.
    graph.update_all(
        fn.copy_e("feat", "feat_copy"), fn.sum("feat_copy", "feat")
    )
    n_node_feats = graph.ndata["feat"].shape[-1]

    # Only the labels in the training set are used as features, while others are filled with zeros.
    graph.ndata["train_labels_onehot"] = torch.zeros(
        graph.number_of_nodes(), n_classes
    )
    graph.ndata["train_labels_onehot"][train_idx, labels[train_idx, 0]] = 1

    graph.ndata["deg"] = graph.out_degrees().float().clamp(min=1)
@@ -99,7 +110,16 @@ def add_labels(graph, idx):
    graph.srcdata["feat"] = torch.cat([feat, train_labels_onehot], dim=-1)


def train(
    args,
    model,
    dataloader,
    _labels,
    _train_idx,
    criterion,
    optimizer,
    _evaluator,
):
    model.train()

    loss_sum, total = 0, 0
@@ -109,7 +129,9 @@ def train(args, model, dataloader, _labels, _train_idx, criterion, optimizer, _e
        new_train_idx = torch.arange(len(output_nodes), device=device)

        if args.use_labels:
            train_labels_idx = torch.arange(
                len(output_nodes), len(input_nodes), device=device
            )
            train_pred_idx = new_train_idx

            add_labels(subgraphs[0], train_labels_idx)
@@ -117,7 +139,10 @@ def train(args, model, dataloader, _labels, _train_idx, criterion, optimizer, _e
            train_pred_idx = new_train_idx

        pred = model(subgraphs)
        loss = criterion(
            pred[train_pred_idx],
            subgraphs[-1].dstdata["labels"][train_pred_idx].float(),
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
@@ -132,7 +157,17 @@ def train(args, model, dataloader, _labels, _train_idx, criterion, optimizer, _e
@torch.no_grad()
def evaluate(
    args,
    model,
    dataloader,
    labels,
    train_idx,
    val_idx,
    test_idx,
    criterion,
    evaluator,
):
    model.eval()

    preds = torch.zeros(labels.shape).to(device)
@@ -170,37 +205,49 @@ def evaluate(args, model, dataloader, labels, train_idx, val_idx, test_idx, crit
    )


def run(
    args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running
):
    evaluator_wrapper = lambda pred, labels: evaluator.eval(
        {"y_pred": pred, "y_true": labels}
    )["rocauc"]

    train_batch_size = (len(train_idx) + 9) // 10
    # batch_size = len(train_idx)
    train_sampler = MultiLayerNeighborSampler(
        [32 for _ in range(args.n_layers)]
    )
    # sampler = MultiLayerFullNeighborSampler(args.n_layers)
    train_dataloader = DataLoader(
        graph.cpu(),
        train_idx.cpu(),
        train_sampler,
        batch_size=train_batch_size,
        num_workers=10,
    )

    eval_sampler = MultiLayerNeighborSampler(
        [100 for _ in range(args.n_layers)]
    )
    # sampler = MultiLayerFullNeighborSampler(args.n_layers)
    eval_dataloader = DataLoader(
        graph.cpu(),
        torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx.cpu()]),
        eval_sampler,
        batch_size=65536,
        num_workers=10,
    )

    criterion = nn.BCEWithLogitsLoss()
    model = gen_model(args).to(device)
    optimizer = optim.AdamW(
        model.parameters(), lr=args.lr, weight_decay=args.wd
    )
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", factor=0.75, patience=50, verbose=True
    )

    total_time = 0
    val_score, best_val_score, final_test_score = 0, 0, 0
@@ -212,14 +259,43 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
    for epoch in range(1, args.n_epochs + 1):
        tic = time.time()

        loss = train(
            args,
            model,
            train_dataloader,
            labels,
            train_idx,
            criterion,
            optimizer,
            evaluator_wrapper,
        )

        toc = time.time()
        total_time += toc - tic

        if (
            epoch == args.n_epochs
            or epoch % args.eval_every == 0
            or epoch % args.log_every == 0
        ):
            (
                train_score,
                val_score,
                test_score,
                train_loss,
                val_loss,
                test_loss,
                pred,
            ) = evaluate(
                args,
                model,
                eval_dataloader,
                labels,
                train_idx,
                val_idx,
                test_idx,
                criterion,
                evaluator_wrapper,
            )

            if val_score > best_val_score:
@@ -238,15 +314,33 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
            )

            for l, e in zip(
                [
                    train_scores,
                    val_scores,
                    test_scores,
                    losses,
                    train_losses,
                    val_losses,
                    test_losses,
                ],
                [
                    train_score,
                    val_score,
                    test_score,
                    loss,
                    train_loss,
                    val_loss,
                    test_loss,
                ],
            ):
                l.append(e)

        lr_scheduler.step(val_score)

    print("*" * 50)
    print(
        f"Best val score: {best_val_score}, Final test score: {final_test_score}"
    )
    print("*" * 50)

    if args.plot:
@@ -255,8 +349,16 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
        ax.set_xticks(np.arange(0, args.n_epochs, 100))
        ax.set_yticks(np.linspace(0, 1.0, 101))
        ax.tick_params(labeltop=True, labelright=True)
        for y, label in zip(
            [train_scores, val_scores, test_scores],
            ["train score", "val score", "test score"],
        ):
            plt.plot(
                range(1, args.n_epochs + 1, args.log_every),
                y,
                label=label,
                linewidth=1,
            )
        ax.xaxis.set_major_locator(MultipleLocator(100))
        ax.xaxis.set_minor_locator(AutoMinorLocator(1))
        ax.yaxis.set_major_locator(MultipleLocator(0.01))
@@ -272,9 +374,15 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
        ax.set_xticks(np.arange(0, args.n_epochs, 100))
        ax.tick_params(labeltop=True, labelright=True)
        for y, label in zip(
            [losses, train_losses, val_losses, test_losses],
            ["loss", "train loss", "val loss", "test loss"],
        ):
            plt.plot(
                range(1, args.n_epochs + 1, args.log_every),
                y,
                label=label,
                linewidth=1,
            )
        ax.xaxis.set_major_locator(MultipleLocator(100))
        ax.xaxis.set_minor_locator(AutoMinorLocator(1))
        ax.yaxis.set_major_locator(MultipleLocator(0.1))
@@ -294,37 +402,79 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
def count_parameters(args):
    model = gen_model(args)
    return sum(
        [np.prod(p.size()) for p in model.parameters() if p.requires_grad]
    )


def main():
    global device

    argparser = argparse.ArgumentParser(
        "GAT implementation on ogbn-proteins",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    argparser.add_argument(
        "--cpu",
        action="store_true",
        help="CPU mode. This option overrides '--gpu'.",
    )
    argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID")
    argparser.add_argument("--seed", type=int, default=0, help="random seed")
    argparser.add_argument(
        "--n-runs", type=int, default=10, help="running times"
    )
    argparser.add_argument(
        "--n-epochs", type=int, default=1200, help="number of epochs"
    )
    argparser.add_argument(
        "--use-labels",
        action="store_true",
        help="Use labels in the training set as input features.",
    )
    argparser.add_argument(
        "--no-attn-dst", action="store_true", help="Don't use attn_dst."
    )
    argparser.add_argument(
        "--n-heads", type=int, default=6, help="number of heads"
    )
    argparser.add_argument(
        "--lr", type=float, default=0.01, help="learning rate"
    )
    argparser.add_argument(
        "--n-layers", type=int, default=6, help="number of layers"
    )
    argparser.add_argument(
        "--n-hidden", type=int, default=80, help="number of hidden units"
    )
    argparser.add_argument(
        "--dropout", type=float, default=0.25, help="dropout rate"
    )
    argparser.add_argument(
        "--input-drop", type=float, default=0.1, help="input drop rate"
    )
    argparser.add_argument(
        "--attn-drop", type=float, default=0.0, help="attention dropout rate"
    )
    argparser.add_argument(
        "--edge-drop", type=float, default=0.1, help="edge drop rate"
    )
    argparser.add_argument("--wd", type=float, default=0, help="weight decay")
    argparser.add_argument(
        "--eval-every",
        type=int,
        default=5,
        help="evaluate every EVAL_EVERY epochs",
    )
    argparser.add_argument(
        "--log-every", type=int, default=5, help="log every LOG_EVERY epochs"
    )
    argparser.add_argument(
        "--plot", action="store_true", help="plot learning curves"
    )
    argparser.add_argument(
        "--save-pred", action="store_true", help="save final predictions"
    )
    args = argparser.parse_args()

    if args.cpu:
@@ -338,7 +488,9 @@ def main():
    print("Preprocessing")
    graph, labels = preprocess(graph, labels, train_idx)

    labels, train_idx, val_idx, test_idx = map(
        lambda x: x.to(device), (labels, train_idx, val_idx, test_idx)
    )

    # run
    val_scores, test_scores = [], []
@@ -346,7 +498,9 @@ def main():
    for i in range(args.n_runs):
        print("Running", i)
        seed(args.seed + i)
        val_score, test_score = run(
            args, graph, labels, train_idx, val_idx, test_idx, evaluator, i + 1
        )
        val_scores.append(val_score)
        test_scores.append(test_score)
......
import os
import time

import numpy as np
import torch
import torch.nn as nn
@@ -11,9 +10,10 @@ from ogb.nodeproppred.dataset_dgl import DglNodePropPredDataset
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter
from utils import load_model, set_random_seed

import dgl.function as fn


def normalize_edge_weights(graph, device, num_ew_channels):
    degs = graph.in_degrees().float()
@@ -24,7 +24,9 @@ def normalize_edge_weights(graph, device, num_ew_channels):
    graph.apply_edges(fn.e_div_u("feat", "norm", "feat"))
    graph.apply_edges(fn.e_div_v("feat", "norm", "feat"))
    for channel in range(num_ew_channels):
        graph.edata["feat_" + str(channel)] = graph.edata["feat"][
            :, channel : channel + 1
        ]
def run_a_train_epoch(graph, node_idx, model, criterion, optimizer, evaluator):
@@ -50,9 +52,24 @@ def run_an_eval_epoch(graph, splitted_idx, model, evaluator):
    labels = graph.ndata["labels"].cpu().numpy()
    preds = logits.cpu().detach().numpy()
    train_score = evaluator.eval(
        {
            "y_true": labels[splitted_idx["train"]],
            "y_pred": preds[splitted_idx["train"]],
        }
    )
    val_score = evaluator.eval(
        {
            "y_true": labels[splitted_idx["valid"]],
            "y_pred": preds[splitted_idx["valid"]],
        }
    )
    test_score = evaluator.eval(
        {
            "y_true": labels[splitted_idx["test"]],
            "y_pred": preds[splitted_idx["test"]],
        }
    )

    return train_score["rocauc"], val_score["rocauc"], test_score["rocauc"]
@@ -75,12 +92,18 @@ def main(args):
    elif args["ewnorm"] == "none":
        print("Not normalizing edge weights")
        for channel in range(args["num_ew_channels"]):
            graph.edata["feat_" + str(channel)] = graph.edata["feat"][
                :, channel : channel + 1
            ]

    model = load_model(args).to(args["device"])
    optimizer = Adam(
        model.parameters(), lr=args["lr"], weight_decay=args["weight_decay"]
    )
    min_lr = 1e-3
    scheduler = ReduceLROnPlateau(
        optimizer, "max", factor=0.7, patience=100, verbose=True, min_lr=min_lr
    )
    print("scheduler min_lr", min_lr)
    criterion = nn.BCEWithLogitsLoss()
@@ -95,7 +118,9 @@ def main(args):
    best_val_score = 0.0
    num_patient_epochs = 0
    model_folder = "./saved_models/"
    model_path = (
        model_folder + str(args["exp_name"]) + "_" + str(args["postfix"])
    )
    if not os.path.exists(model_folder):
        os.makedirs(model_folder)
@@ -104,7 +129,9 @@ def main(args):
        if epoch >= 3:
            t0 = time.time()

        loss, train_score = run_a_train_epoch(
            graph, splitted_idx["train"], model, criterion, optimizer, evaluator
        )

        if epoch >= 3:
            dur.append(time.time() - t0)
@@ -112,7 +139,9 @@ def main(args):
        else:
            avg_time = None

        train_score, val_score, test_score = run_an_eval_epoch(
            graph, splitted_idx, model, evaluator
        )

        scheduler.step(val_score)
@@ -127,7 +156,12 @@ def main(args):
            print(
                "Epoch {:d}, loss {:.4f}, train score {:.4f}, "
                "val score {:.4f}, avg time {}, num patient epochs {:d}".format(
                    epoch,
                    loss,
                    train_score,
                    val_score,
                    avg_time,
                    num_patient_epochs,
                )
            )
@@ -135,7 +169,9 @@ def main(args):
            break

    model.load_state_dict(torch.load(model_path))
    train_score, val_score, test_score = run_an_eval_epoch(
        graph, splitted_idx, model, evaluator
    )
    print("Train score {:.4f}".format(train_score))
    print("Valid score {:.4f}".format(val_score))
    print("Test score {:.4f}".format(test_score))
@@ -153,15 +189,34 @@ if __name__ == "__main__":
    from configure import get_exp_configure

    parser = argparse.ArgumentParser(
        description="OGB node property prediction with DGL using full graph training"
    )
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        choices=["MWE-GCN", "MWE-DGCN"],
        default="MWE-DGCN",
        help="Model to use",
    )
    parser.add_argument("-c", "--cuda", type=str, default="none")
    parser.add_argument(
        "--postfix",
        type=str,
        default="",
        help="a string appended to the file name of the saved model",
    )
    parser.add_argument(
        "--rand_seed",
        type=int,
        default=-1,
        help="random seed for torch and numpy",
    )
    parser.add_argument("--residual", action="store_true")
    parser.add_argument(
        "--ewnorm", type=str, default="none", choices=["none", "both"]
    )
    args = parser.parse_args().__dict__

    # Get experiment configuration
......
@@ -3,7 +3,6 @@ import random
import numpy as np
import torch
import torch.nn.functional as F
from models import MWE_DGCN, MWE_GCN
@@ -87,4 +86,3 @@ class Logger(object):
            print(f"   Final Train: {r.mean():.2f} ± {r.std():.2f}")
            r = best_result[:, 3]
            print(f"   Final Test: {r.mean():.2f} ± {r.std():.2f}")
import argparse
import math
import os
import random
import sys
import time

import numpy as np
import torch
import torch.nn.functional as F
from ogb.linkproppred import DglLinkPropPredDataset, Evaluator
from scipy.sparse.csgraph import shortest_path
from torch.nn import (
    BCEWithLogitsLoss,
    Conv1d,
    Embedding,
    Linear,
    MaxPool1d,
    ModuleList,
)
from tqdm import tqdm

import dgl
from dgl.dataloading import DataLoader, Sampler
from dgl.nn import GraphConv, SortPooling
from dgl.sampling import global_uniform_negative_sampling
class Logger(object):
@@ -32,10 +41,10 @@ class Logger(object):
        if run is not None:
            result = 100 * torch.tensor(self.results[run])
            argmax = result[:, 0].argmax().item()
            print(f"Run {run + 1:02d}:", file=f)
            print(f"Highest Valid: {result[:, 0].max():.2f}", file=f)
            print(f"Highest Eval Point: {argmax + 1}", file=f)
            print(f"   Final Test: {result[argmax, 1]:.2f}", file=f)
        else:
            result = 100 * torch.tensor(self.results)
@@ -47,16 +56,23 @@ class Logger(object):
            best_result = torch.tensor(best_results)

            print(f"All runs:", file=f)
            r = best_result[:, 0]
            print(f"Highest Valid: {r.mean():.2f} ± {r.std():.2f}", file=f)
            r = best_result[:, 1]
            print(f"   Final Test: {r.mean():.2f} ± {r.std():.2f}", file=f)
class SealSampler(Sampler):
    def __init__(
        self,
        g,
        num_hops=1,
        sample_ratio=1.0,
        directed=False,
        prefetch_node_feats=None,
        prefetch_edge_feats=None,
    ):
        super().__init__()
        self.g = g
        self.num_hops = num_hops
@@ -71,22 +87,29 @@ class SealSampler(Sampler):
        idx = list(range(1)) + list(range(2, N))
        adj_wo_dst = adj[idx, :][:, idx]

        dist2src = shortest_path(
            adj_wo_dst, directed=False, unweighted=True, indices=0
        )
        dist2src = np.insert(dist2src, 1, 0, axis=0)
        dist2src = torch.from_numpy(dist2src)

        dist2dst = shortest_path(
            adj_wo_src, directed=False, unweighted=True, indices=0
        )
        dist2dst = np.insert(dist2dst, 0, 0, axis=0)
        dist2dst = torch.from_numpy(dist2dst)

        dist = dist2src + dist2dst
        dist_over_2, dist_mod_2 = (
            torch.div(dist, 2, rounding_mode="floor"),
            dist % 2,
        )

        z = 1 + torch.min(dist2src, dist2dst)
        z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1)
        z[0:2] = 1.0
        # shortest path may include inf values
        z[torch.isnan(z)] = 0.0
        return z.to(torch.long)
@@ -107,9 +130,12 @@ class SealSampler(Sampler):
fringe = np.union1d(in_neighbors, out_neighbors) fringe = np.union1d(in_neighbors, out_neighbors)
fringe = np.setdiff1d(fringe, visited) fringe = np.setdiff1d(fringe, visited)
visited = np.union1d(visited, fringe) visited = np.union1d(visited, fringe)
if self.sample_ratio < 1.: if self.sample_ratio < 1.0:
fringe = np.random.choice(fringe, fringe = np.random.choice(
int(self.sample_ratio * len(fringe)), replace=False) fringe,
int(self.sample_ratio * len(fringe)),
replace=False,
)
if len(fringe) == 0: if len(fringe) == 0:
break break
nodes = np.union1d(nodes, fringe) nodes = np.union1d(nodes, fringe)
@@ -117,26 +143,34 @@ class SealSampler(Sampler):
# remove edges to predict # remove edges to predict
edges_to_remove = [ edges_to_remove = [
subg.edge_ids(s, t) for s, t in [(0, 1), (1, 0)] if subg.has_edges_between(s, t)] subg.edge_ids(s, t)
for s, t in [(0, 1), (1, 0)]
if subg.has_edges_between(s, t)
]
subg.remove_edges(edges_to_remove) subg.remove_edges(edges_to_remove)
# add double radius node labeling # add double radius node labeling
subg.ndata['z'] = self._double_radius_node_labeling(subg.adj(scipy_fmt='csr')) subg.ndata["z"] = self._double_radius_node_labeling(
subg.adj(scipy_fmt="csr")
)
subg_aug = subg.add_self_loop() subg_aug = subg.add_self_loop()
if 'weight' in subg.edata: if "weight" in subg.edata:
subg_aug.edata['weight'][subg.num_edges():] = torch.ones( subg_aug.edata["weight"][subg.num_edges() :] = torch.ones(
subg_aug.num_edges() - subg.num_edges()) subg_aug.num_edges() - subg.num_edges()
)
subgraphs.append(subg_aug) subgraphs.append(subg_aug)
subgraphs = dgl.batch(subgraphs) subgraphs = dgl.batch(subgraphs)
dgl.set_src_lazy_features(subg_aug, self.prefetch_node_feats) dgl.set_src_lazy_features(subg_aug, self.prefetch_node_feats)
dgl.set_edge_lazy_features(subg_aug, self.prefetch_edge_feats) dgl.set_edge_lazy_features(subg_aug, self.prefetch_edge_feats)
return subgraphs, aug_g.edata['y'][seed_edges] return subgraphs, aug_g.edata["y"][seed_edges]
# An end-to-end deep learning architecture for graph classification, AAAI-18. # An end-to-end deep learning architecture for graph classification, AAAI-18.
class DGCNN(torch.nn.Module): class DGCNN(torch.nn.Module):
def __init__(self, hidden_channels, num_layers, k, GNN=GraphConv, feature_dim=0): def __init__(
self, hidden_channels, num_layers, k, GNN=GraphConv, feature_dim=0
):
super(DGCNN, self).__init__() super(DGCNN, self).__init__()
self.feature_dim = feature_dim self.feature_dim = feature_dim
self.k = k self.k = k
@@ -149,18 +183,18 @@ class DGCNN(torch.nn.Module):
initial_channels = hidden_channels + self.feature_dim initial_channels = hidden_channels + self.feature_dim
self.convs.append(GNN(initial_channels, hidden_channels)) self.convs.append(GNN(initial_channels, hidden_channels))
for _ in range(0, num_layers-1): for _ in range(0, num_layers - 1):
self.convs.append(GNN(hidden_channels, hidden_channels)) self.convs.append(GNN(hidden_channels, hidden_channels))
self.convs.append(GNN(hidden_channels, 1)) self.convs.append(GNN(hidden_channels, 1))
conv1d_channels = [16, 32] conv1d_channels = [16, 32]
total_latent_dim = hidden_channels * num_layers + 1 total_latent_dim = hidden_channels * num_layers + 1
conv1d_kws = [total_latent_dim, 5] conv1d_kws = [total_latent_dim, 5]
self.conv1 = Conv1d(1, conv1d_channels[0], conv1d_kws[0], self.conv1 = Conv1d(1, conv1d_channels[0], conv1d_kws[0], conv1d_kws[0])
conv1d_kws[0])
self.maxpool1d = MaxPool1d(2, 2) self.maxpool1d = MaxPool1d(2, 2)
self.conv2 = Conv1d(conv1d_channels[0], conv1d_channels[1], self.conv2 = Conv1d(
conv1d_kws[1], 1) conv1d_channels[0], conv1d_channels[1], conv1d_kws[1], 1
)
dense_dim = int((self.k - 2) / 2 + 1) dense_dim = int((self.k - 2) / 2 + 1)
dense_dim = (dense_dim - conv1d_kws[1] + 1) * conv1d_channels[1] dense_dim = (dense_dim - conv1d_kws[1] + 1) * conv1d_channels[1]
self.lin1 = Linear(dense_dim, 128) self.lin1 = Linear(dense_dim, 128)
@@ -196,33 +230,35 @@ class DGCNN(torch.nn.Module):
def get_pos_neg_edges(split, split_edge, g, percent=100): def get_pos_neg_edges(split, split_edge, g, percent=100):
pos_edge = split_edge[split]['edge'] pos_edge = split_edge[split]["edge"]
if split == 'train': if split == "train":
neg_edge = torch.stack(global_uniform_negative_sampling( neg_edge = torch.stack(
g, num_samples=pos_edge.size(0), global_uniform_negative_sampling(
exclude_self_loops=True g, num_samples=pos_edge.size(0), exclude_self_loops=True
), dim=1) ),
dim=1,
)
else: else:
neg_edge = split_edge[split]['edge_neg'] neg_edge = split_edge[split]["edge_neg"]
# sampling according to the percent param # sampling according to the percent param
np.random.seed(123) np.random.seed(123)
# pos sampling # pos sampling
num_pos = pos_edge.size(0) num_pos = pos_edge.size(0)
perm = np.random.permutation(num_pos) perm = np.random.permutation(num_pos)
perm = perm[:int(percent / 100 * num_pos)] perm = perm[: int(percent / 100 * num_pos)]
pos_edge = pos_edge[perm] pos_edge = pos_edge[perm]
# neg sampling # neg sampling
if neg_edge.dim() > 2: # [Np, Nn, 2] if neg_edge.dim() > 2: # [Np, Nn, 2]
neg_edge = neg_edge[perm].view(-1, 2) neg_edge = neg_edge[perm].view(-1, 2)
else: else:
np.random.seed(123) np.random.seed(123)
num_neg = neg_edge.size(0) num_neg = neg_edge.size(0)
perm = np.random.permutation(num_neg) perm = np.random.permutation(num_neg)
perm = perm[:int(percent / 100 * num_neg)] perm = perm[: int(percent / 100 * num_neg)]
neg_edge = neg_edge[perm] neg_edge = neg_edge[perm]
return pos_edge, neg_edge # ([2, Np], [2, Nn]) -> ([Np, 2], [Nn, 2]) return pos_edge, neg_edge # ([2, Np], [2, Nn]) -> ([Np, 2], [Nn, 2])
def train(): def train():
@@ -233,8 +269,12 @@ def train():
pbar = tqdm(train_loader, ncols=70) pbar = tqdm(train_loader, ncols=70)
for gs, y in pbar: for gs, y in pbar:
optimizer.zero_grad() optimizer.zero_grad()
logits = model(gs, gs.ndata['z'], gs.ndata.get('feat', None), logits = model(
edge_weight=gs.edata.get('weight', None)) gs,
gs.ndata["z"],
gs.ndata.get("feat", None),
edge_weight=gs.edata.get("weight", None),
)
loss = loss_fnt(logits.view(-1), y.to(torch.float)) loss = loss_fnt(logits.view(-1), y.to(torch.float))
loss.backward() loss.backward()
optimizer.step() optimizer.step()
@@ -250,28 +290,40 @@ def test():
y_pred, y_true = [], [] y_pred, y_true = [], []
for gs, y in tqdm(val_loader, ncols=70): for gs, y in tqdm(val_loader, ncols=70):
logits = model(gs, gs.ndata['z'], gs.ndata.get('feat', None), logits = model(
edge_weight=gs.edata.get('weight', None)) gs,
gs.ndata["z"],
gs.ndata.get("feat", None),
edge_weight=gs.edata.get("weight", None),
)
y_pred.append(logits.view(-1).cpu()) y_pred.append(logits.view(-1).cpu())
y_true.append(y.view(-1).cpu().to(torch.float)) y_true.append(y.view(-1).cpu().to(torch.float))
val_pred, val_true = torch.cat(y_pred), torch.cat(y_true) val_pred, val_true = torch.cat(y_pred), torch.cat(y_true)
pos_val_pred = val_pred[val_true==1] pos_val_pred = val_pred[val_true == 1]
neg_val_pred = val_pred[val_true==0] neg_val_pred = val_pred[val_true == 0]
y_pred, y_true = [], [] y_pred, y_true = [], []
for gs, y in tqdm(test_loader, ncols=70): for gs, y in tqdm(test_loader, ncols=70):
logits = model(gs, gs.ndata['z'], gs.ndata.get('feat', None), logits = model(
edge_weight=gs.edata.get('weight', None)) gs,
gs.ndata["z"],
gs.ndata.get("feat", None),
edge_weight=gs.edata.get("weight", None),
)
y_pred.append(logits.view(-1).cpu()) y_pred.append(logits.view(-1).cpu())
y_true.append(y.view(-1).cpu().to(torch.float)) y_true.append(y.view(-1).cpu().to(torch.float))
test_pred, test_true = torch.cat(y_pred), torch.cat(y_true) test_pred, test_true = torch.cat(y_pred), torch.cat(y_true)
pos_test_pred = test_pred[test_true==1] pos_test_pred = test_pred[test_true == 1]
neg_test_pred = test_pred[test_true==0] neg_test_pred = test_pred[test_true == 0]
if args.eval_metric == 'hits': if args.eval_metric == "hits":
results = evaluate_hits(pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred) results = evaluate_hits(
elif args.eval_metric == 'mrr': pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred
results = evaluate_mrr(pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred) )
elif args.eval_metric == "mrr":
results = evaluate_mrr(
pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred
)
return results return results
@@ -280,184 +332,254 @@ def evaluate_hits(pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred):
results = {} results = {}
for K in [20, 50, 100]: for K in [20, 50, 100]:
evaluator.K = K evaluator.K = K
valid_hits = evaluator.eval({ valid_hits = evaluator.eval(
'y_pred_pos': pos_val_pred, {
'y_pred_neg': neg_val_pred, "y_pred_pos": pos_val_pred,
})[f'hits@{K}'] "y_pred_neg": neg_val_pred,
test_hits = evaluator.eval({ }
'y_pred_pos': pos_test_pred, )[f"hits@{K}"]
'y_pred_neg': neg_test_pred, test_hits = evaluator.eval(
})[f'hits@{K}'] {
"y_pred_pos": pos_test_pred,
results[f'Hits@{K}'] = (valid_hits, test_hits) "y_pred_neg": neg_test_pred,
}
)[f"hits@{K}"]
results[f"Hits@{K}"] = (valid_hits, test_hits)
return results return results
def evaluate_mrr(pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred): def evaluate_mrr(pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred):
print(pos_val_pred.size(), neg_val_pred.size(), pos_test_pred.size(), neg_test_pred.size()) print(
pos_val_pred.size(),
neg_val_pred.size(),
pos_test_pred.size(),
neg_test_pred.size(),
)
neg_val_pred = neg_val_pred.view(pos_val_pred.shape[0], -1) neg_val_pred = neg_val_pred.view(pos_val_pred.shape[0], -1)
neg_test_pred = neg_test_pred.view(pos_test_pred.shape[0], -1) neg_test_pred = neg_test_pred.view(pos_test_pred.shape[0], -1)
results = {} results = {}
valid_mrr = evaluator.eval({ valid_mrr = (
'y_pred_pos': pos_val_pred, evaluator.eval(
'y_pred_neg': neg_val_pred, {
})['mrr_list'].mean().item() "y_pred_pos": pos_val_pred,
"y_pred_neg": neg_val_pred,
test_mrr = evaluator.eval({ }
'y_pred_pos': pos_test_pred, )["mrr_list"]
'y_pred_neg': neg_test_pred, .mean()
})['mrr_list'].mean().item() .item()
)
results['MRR'] = (valid_mrr, test_mrr)
test_mrr = (
evaluator.eval(
{
"y_pred_pos": pos_test_pred,
"y_pred_neg": neg_test_pred,
}
)["mrr_list"]
.mean()
.item()
)
results["MRR"] = (valid_mrr, test_mrr)
return results return results
if __name__ == '__main__': if __name__ == "__main__":
# Data settings # Data settings
parser = argparse.ArgumentParser(description='OGBL (SEAL)') parser = argparse.ArgumentParser(description="OGBL (SEAL)")
parser.add_argument('--dataset', type=str, default='ogbl-collab') parser.add_argument("--dataset", type=str, default="ogbl-collab")
# GNN settings # GNN settings
parser.add_argument('--sortpool_k', type=float, default=0.6) parser.add_argument("--sortpool_k", type=float, default=0.6)
parser.add_argument('--num_layers', type=int, default=3) parser.add_argument("--num_layers", type=int, default=3)
parser.add_argument('--hidden_channels', type=int, default=32) parser.add_argument("--hidden_channels", type=int, default=32)
parser.add_argument('--batch_size', type=int, default=32) parser.add_argument("--batch_size", type=int, default=32)
# Subgraph extraction settings # Subgraph extraction settings
parser.add_argument('--ratio_per_hop', type=float, default=1.0) parser.add_argument("--ratio_per_hop", type=float, default=1.0)
parser.add_argument('--use_feature', action='store_true', parser.add_argument(
help="whether to use raw node features as GNN input") "--use_feature",
parser.add_argument('--use_edge_weight', action='store_true', action="store_true",
help="whether to consider edge weight in GNN") help="whether to use raw node features as GNN input",
)
parser.add_argument(
"--use_edge_weight",
action="store_true",
help="whether to consider edge weight in GNN",
)
# Training settings # Training settings
parser.add_argument('--lr', type=float, default=0.0001) parser.add_argument("--lr", type=float, default=0.0001)
parser.add_argument('--epochs', type=int, default=50) parser.add_argument("--epochs", type=int, default=50)
parser.add_argument('--runs', type=int, default=10) parser.add_argument("--runs", type=int, default=10)
parser.add_argument('--train_percent', type=float, default=100) parser.add_argument("--train_percent", type=float, default=100)
parser.add_argument('--val_percent', type=float, default=100) parser.add_argument("--val_percent", type=float, default=100)
parser.add_argument('--test_percent', type=float, default=100) parser.add_argument("--test_percent", type=float, default=100)
parser.add_argument('--num_workers', type=int, default=8, parser.add_argument(
help="number of workers for dynamic dataloaders") "--num_workers",
type=int,
default=8,
help="number of workers for dynamic dataloaders",
)
# Testing settings # Testing settings
parser.add_argument('--use_valedges_as_input', action='store_true') parser.add_argument("--use_valedges_as_input", action="store_true")
parser.add_argument('--eval_steps', type=int, default=1) parser.add_argument("--eval_steps", type=int, default=1)
args = parser.parse_args() args = parser.parse_args()
data_appendix = '_rph{}'.format(''.join(str(args.ratio_per_hop).split('.'))) data_appendix = "_rph{}".format("".join(str(args.ratio_per_hop).split(".")))
if args.use_valedges_as_input: if args.use_valedges_as_input:
data_appendix += '_uvai' data_appendix += "_uvai"
args.res_dir = os.path.join('results/{}_{}'.format(args.dataset, args.res_dir = os.path.join(
time.strftime("%Y%m%d%H%M%S"))) "results/{}_{}".format(args.dataset, time.strftime("%Y%m%d%H%M%S"))
print('Results will be saved in ' + args.res_dir) )
print("Results will be saved in " + args.res_dir)
if not os.path.exists(args.res_dir): if not os.path.exists(args.res_dir):
os.makedirs(args.res_dir) os.makedirs(args.res_dir)
log_file = os.path.join(args.res_dir, 'log.txt') log_file = os.path.join(args.res_dir, "log.txt")
# Save command line input. # Save command line input.
cmd_input = 'python ' + ' '.join(sys.argv) + '\n' cmd_input = "python " + " ".join(sys.argv) + "\n"
with open(os.path.join(args.res_dir, 'cmd_input.txt'), 'a') as f: with open(os.path.join(args.res_dir, "cmd_input.txt"), "a") as f:
f.write(cmd_input) f.write(cmd_input)
print('Command line input: ' + cmd_input + ' is saved.') print("Command line input: " + cmd_input + " is saved.")
with open(log_file, 'a') as f: with open(log_file, "a") as f:
f.write('\n' + cmd_input) f.write("\n" + cmd_input)
dataset = DglLinkPropPredDataset(name=args.dataset) dataset = DglLinkPropPredDataset(name=args.dataset)
split_edge = dataset.get_edge_split() split_edge = dataset.get_edge_split()
graph = dataset[0] graph = dataset[0]
# re-format the data of citation2 # re-format the data of citation2
if args.dataset == 'ogbl-citation2': if args.dataset == "ogbl-citation2":
for k in ['train', 'valid', 'test']: for k in ["train", "valid", "test"]:
src = split_edge[k]['source_node'] src = split_edge[k]["source_node"]
tgt = split_edge[k]['target_node'] tgt = split_edge[k]["target_node"]
split_edge[k]['edge'] = torch.stack([src, tgt], dim=1) split_edge[k]["edge"] = torch.stack([src, tgt], dim=1)
if k != 'train': if k != "train":
tgt_neg = split_edge[k]['target_node_neg'] tgt_neg = split_edge[k]["target_node_neg"]
split_edge[k]['edge_neg'] = torch.stack([ split_edge[k]["edge_neg"] = torch.stack(
src[:, None].repeat(1, tgt_neg.size(1)), [src[:, None].repeat(1, tgt_neg.size(1)), tgt_neg], dim=-1
tgt_neg ) # [Ns, Nt, 2]
], dim=-1) # [Ns, Nt, 2]
# reconstruct the graph for ogbl-collab data for validation edge augmentation and coalesce # reconstruct the graph for ogbl-collab data for validation edge augmentation and coalesce
if args.dataset == 'ogbl-collab': if args.dataset == "ogbl-collab":
if args.use_valedges_as_input: if args.use_valedges_as_input:
val_edges = split_edge['valid']['edge'] val_edges = split_edge["valid"]["edge"]
row, col = val_edges.t() row, col = val_edges.t()
# float edata for to_simple transform # float edata for to_simple transform
graph.edata.pop('year') graph.edata.pop("year")
graph.edata['weight'] = graph.edata['weight'].to(torch.float) graph.edata["weight"] = graph.edata["weight"].to(torch.float)
val_weights = torch.ones(size=(val_edges.size(0), 1)) val_weights = torch.ones(size=(val_edges.size(0), 1))
graph.add_edges(torch.cat([row, col]), torch.cat([col, row]), {'weight': val_weights}) graph.add_edges(
graph = graph.to_simple(copy_edata=True, aggregator='sum') torch.cat([row, col]),
torch.cat([col, row]),
if not args.use_edge_weight and 'weight' in graph.edata: {"weight": val_weights},
graph.edata.pop('weight') )
if not args.use_feature and 'feat' in graph.ndata: graph = graph.to_simple(copy_edata=True, aggregator="sum")
graph.ndata.pop('feat')
if not args.use_edge_weight and "weight" in graph.edata:
if args.dataset.startswith('ogbl-citation'): graph.edata.pop("weight")
args.eval_metric = 'mrr' if not args.use_feature and "feat" in graph.ndata:
graph.ndata.pop("feat")
if args.dataset.startswith("ogbl-citation"):
args.eval_metric = "mrr"
directed = True directed = True
else: else:
args.eval_metric = 'hits' args.eval_metric = "hits"
directed = False directed = False
evaluator = Evaluator(name=args.dataset) evaluator = Evaluator(name=args.dataset)
if args.eval_metric == 'hits': if args.eval_metric == "hits":
loggers = { loggers = {
'Hits@20': Logger(args.runs, args), "Hits@20": Logger(args.runs, args),
'Hits@50': Logger(args.runs, args), "Hits@50": Logger(args.runs, args),
'Hits@100': Logger(args.runs, args), "Hits@100": Logger(args.runs, args),
} }
elif args.eval_metric == 'mrr': elif args.eval_metric == "mrr":
loggers = { loggers = {
'MRR': Logger(args.runs, args), "MRR": Logger(args.runs, args),
} }
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
path = dataset.root + '_seal{}'.format(data_appendix) path = dataset.root + "_seal{}".format(data_appendix)
loaders = [] loaders = []
prefetch_node_feats = ['feat'] if 'feat' in graph.ndata else None prefetch_node_feats = ["feat"] if "feat" in graph.ndata else None
prefetch_edge_feats = ['weight'] if 'weight' in graph.edata else None prefetch_edge_feats = ["weight"] if "weight" in graph.edata else None
train_edge, train_edge_neg = get_pos_neg_edges('train', split_edge, graph, args.train_percent) train_edge, train_edge_neg = get_pos_neg_edges(
val_edge, val_edge_neg = get_pos_neg_edges('valid', split_edge, graph, args.val_percent) "train", split_edge, graph, args.train_percent
test_edge, test_edge_neg = get_pos_neg_edges('test', split_edge, graph, args.test_percent) )
val_edge, val_edge_neg = get_pos_neg_edges(
"valid", split_edge, graph, args.val_percent
)
test_edge, test_edge_neg = get_pos_neg_edges(
"test", split_edge, graph, args.test_percent
)
# create an augmented graph for sampling # create an augmented graph for sampling
aug_g = dgl.graph(graph.edges()) aug_g = dgl.graph(graph.edges())
aug_g.edata['y'] = torch.ones(aug_g.num_edges()) aug_g.edata["y"] = torch.ones(aug_g.num_edges())
aug_edges = torch.cat([val_edge, test_edge, train_edge_neg, val_edge_neg, test_edge_neg]) aug_edges = torch.cat(
aug_labels = torch.cat([ [val_edge, test_edge, train_edge_neg, val_edge_neg, test_edge_neg]
torch.ones(len(val_edge) + len(test_edge)), )
torch.zeros(len(train_edge_neg) + len(val_edge_neg) + len(test_edge_neg)) aug_labels = torch.cat(
]) [
aug_g.add_edges(aug_edges[:, 0], aug_edges[:, 1], {'y': aug_labels}) torch.ones(len(val_edge) + len(test_edge)),
torch.zeros(
len(train_edge_neg) + len(val_edge_neg) + len(test_edge_neg)
),
]
)
aug_g.add_edges(aug_edges[:, 0], aug_edges[:, 1], {"y": aug_labels})
# eids for sampling # eids for sampling
split_len = [graph.num_edges()] + \ split_len = [graph.num_edges()] + list(
list(map(len, [val_edge, test_edge, train_edge_neg, val_edge_neg, test_edge_neg])) map(
train_eids = torch.cat([ len,
graph.edge_ids(train_edge[:, 0], train_edge[:, 1]), [val_edge, test_edge, train_edge_neg, val_edge_neg, test_edge_neg],
torch.arange(sum(split_len[:3]), sum(split_len[:4])) )
]) )
val_eids = torch.cat([ train_eids = torch.cat(
torch.arange(sum(split_len[:1]), sum(split_len[:2])), [
torch.arange(sum(split_len[:4]), sum(split_len[:5])) graph.edge_ids(train_edge[:, 0], train_edge[:, 1]),
]) torch.arange(sum(split_len[:3]), sum(split_len[:4])),
test_eids = torch.cat([ ]
torch.arange(sum(split_len[:2]), sum(split_len[:3])), )
torch.arange(sum(split_len[:5]), sum(split_len[:6])) val_eids = torch.cat(
]) [
sampler = SealSampler(graph, 1, args.ratio_per_hop, directed, torch.arange(sum(split_len[:1]), sum(split_len[:2])),
prefetch_node_feats, prefetch_edge_feats) torch.arange(sum(split_len[:4]), sum(split_len[:5])),
]
)
test_eids = torch.cat(
[
torch.arange(sum(split_len[:2]), sum(split_len[:3])),
torch.arange(sum(split_len[:5]), sum(split_len[:6])),
]
)
sampler = SealSampler(
graph,
1,
args.ratio_per_hop,
directed,
prefetch_node_feats,
prefetch_edge_feats,
)
# force to be dynamic for consistent dataloading # force to be dynamic for consistent dataloading
for split, shuffle, eids in zip( for split, shuffle, eids in zip(
['train', 'valid', 'test'], ["train", "valid", "test"],
[True, False, False], [True, False, False],
[train_eids, val_eids, test_eids] [train_eids, val_eids, test_eids],
): ):
data_loader = DataLoader(aug_g, eids, sampler, shuffle=shuffle, device=device, data_loader = DataLoader(
batch_size=args.batch_size, num_workers=args.num_workers) aug_g,
eids,
sampler,
shuffle=shuffle,
device=device,
batch_size=args.batch_size,
num_workers=args.num_workers,
)
loaders.append(data_loader) loaders.append(data_loader)
train_loader, val_loader, test_loader = loaders train_loader, val_loader, test_loader = loaders
@@ -474,16 +596,20 @@ if __name__ == '__main__':
k = max(k, 10) k = max(k, 10)
for run in range(args.runs): for run in range(args.runs):
model = DGCNN(args.hidden_channels, args.num_layers, k, model = DGCNN(
feature_dim=graph.ndata['feat'].size(1) if args.use_feature else 0).to(device) args.hidden_channels,
args.num_layers,
k,
feature_dim=graph.ndata["feat"].size(1) if args.use_feature else 0,
).to(device)
parameters = list(model.parameters()) parameters = list(model.parameters())
optimizer = torch.optim.Adam(params=parameters, lr=args.lr) optimizer = torch.optim.Adam(params=parameters, lr=args.lr)
total_params = sum(p.numel() for param in parameters for p in param) total_params = sum(p.numel() for param in parameters for p in param)
print(f'Total number of parameters is {total_params}') print(f"Total number of parameters is {total_params}")
print(f'SortPooling k is set to {k}') print(f"SortPooling k is set to {k}")
with open(log_file, 'a') as f: with open(log_file, "a") as f:
print(f'Total number of parameters is {total_params}', file=f) print(f"Total number of parameters is {total_params}", file=f)
print(f'SortPooling k is set to {k}', file=f) print(f"SortPooling k is set to {k}", file=f)
start_epoch = 1 start_epoch = 1
# Training starts # Training starts
@@ -496,35 +622,41 @@ if __name__ == '__main__':
loggers[key].add_result(run, result) loggers[key].add_result(run, result)
model_name = os.path.join( model_name = os.path.join(
args.res_dir, 'run{}_model_checkpoint{}.pth'.format(run+1, epoch)) args.res_dir,
"run{}_model_checkpoint{}.pth".format(run + 1, epoch),
)
optimizer_name = os.path.join( optimizer_name = os.path.join(
args.res_dir, 'run{}_optimizer_checkpoint{}.pth'.format(run+1, epoch)) args.res_dir,
"run{}_optimizer_checkpoint{}.pth".format(run + 1, epoch),
)
torch.save(model.state_dict(), model_name) torch.save(model.state_dict(), model_name)
torch.save(optimizer.state_dict(), optimizer_name) torch.save(optimizer.state_dict(), optimizer_name)
for key, result in results.items(): for key, result in results.items():
valid_res, test_res = result valid_res, test_res = result
to_print = (f'Run: {run + 1:02d}, Epoch: {epoch:02d}, ' + to_print = (
f'Loss: {loss:.4f}, Valid: {100 * valid_res:.2f}%, ' + f"Run: {run + 1:02d}, Epoch: {epoch:02d}, "
f'Test: {100 * test_res:.2f}%') + f"Loss: {loss:.4f}, Valid: {100 * valid_res:.2f}%, "
+ f"Test: {100 * test_res:.2f}%"
)
print(key) print(key)
print(to_print) print(to_print)
with open(log_file, 'a') as f: with open(log_file, "a") as f:
print(key, file=f) print(key, file=f)
print(to_print, file=f) print(to_print, file=f)
for key in loggers.keys(): for key in loggers.keys():
print(key) print(key)
loggers[key].print_statistics(run) loggers[key].print_statistics(run)
with open(log_file, 'a') as f: with open(log_file, "a") as f:
print(key, file=f) print(key, file=f)
loggers[key].print_statistics(run, f=f) loggers[key].print_statistics(run, f=f)
for key in loggers.keys(): for key in loggers.keys():
print(key) print(key)
loggers[key].print_statistics() loggers[key].print_statistics()
with open(log_file, 'a') as f: with open(log_file, "a") as f:
print(key, file=f) print(key, file=f)
loggers[key].print_statistics(f=f) loggers[key].print_statistics(f=f)
print(f'Total number of parameters is {total_params}') print(f"Total number of parameters is {total_params}")
print(f'Results are saved in {args.res_dir}') print(f"Results are saved in {args.res_dir}")
import torch
import numpy as np import numpy as np
import torch
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
import dgl import dgl
import dgl.function as fn import dgl.function as fn
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
def get_ogb_evaluator(dataset): def get_ogb_evaluator(dataset):
...@@ -10,10 +11,12 @@ def get_ogb_evaluator(dataset): ...@@ -10,10 +11,12 @@ def get_ogb_evaluator(dataset):
Get evaluator from Open Graph Benchmark based on dataset Get evaluator from Open Graph Benchmark based on dataset
""" """
evaluator = Evaluator(name=dataset) evaluator = Evaluator(name=dataset)
return lambda preds, labels: evaluator.eval({ return lambda preds, labels: evaluator.eval(
"y_true": labels.view(-1, 1), {
"y_pred": preds.view(-1, 1), "y_true": labels.view(-1, 1),
})["acc"] "y_pred": preds.view(-1, 1),
}
)["acc"]
def convert_mag_to_homograph(g, device): def convert_mag_to_homograph(g, device):
...@@ -25,11 +28,13 @@ def convert_mag_to_homograph(g, device): ...@@ -25,11 +28,13 @@ def convert_mag_to_homograph(g, device):
src_writes, dst_writes = g.all_edges(etype="writes") src_writes, dst_writes = g.all_edges(etype="writes")
src_topic, dst_topic = g.all_edges(etype="has_topic") src_topic, dst_topic = g.all_edges(etype="has_topic")
src_aff, dst_aff = g.all_edges(etype="affiliated_with") src_aff, dst_aff = g.all_edges(etype="affiliated_with")
new_g = dgl.heterograph({ new_g = dgl.heterograph(
("paper", "written", "author"): (dst_writes, src_writes), {
("paper", "has_topic", "field"): (src_topic, dst_topic), ("paper", "written", "author"): (dst_writes, src_writes),
("author", "aff", "inst"): (src_aff, dst_aff) ("paper", "has_topic", "field"): (src_topic, dst_topic),
}) ("author", "aff", "inst"): (src_aff, dst_aff),
}
)
new_g = new_g.to(device) new_g = new_g.to(device)
new_g.nodes["paper"].data["feat"] = g.nodes["paper"].data["feat"] new_g.nodes["paper"].data["feat"] = g.nodes["paper"].data["feat"]
new_g["written"].update_all(fn.copy_u("feat", "m"), fn.mean("m", "feat")) new_g["written"].update_all(fn.copy_u("feat", "m"), fn.mean("m", "feat"))
...@@ -65,7 +70,7 @@ def load_dataset(name, device): ...@@ -65,7 +70,7 @@ def load_dataset(name, device):
if name == "ogbn-arxiv": if name == "ogbn-arxiv":
g = dgl.add_reverse_edges(g, copy_ndata=True) g = dgl.add_reverse_edges(g, copy_ndata=True)
g = dgl.add_self_loop(g) g = dgl.add_self_loop(g)
g.ndata['feat'] = g.ndata['feat'].float() g.ndata["feat"] = g.ndata["feat"].float()
elif name == "ogbn-mag": elif name == "ogbn-mag":
# MAG is a heterogeneous graph. The task is to make prediction for # MAG is a heterogeneous graph. The task is to make prediction for
# paper nodes # paper nodes
...@@ -75,16 +80,18 @@ def load_dataset(name, device): ...@@ -75,16 +80,18 @@ def load_dataset(name, device):
test_nid = test_nid["paper"] test_nid = test_nid["paper"]
g = convert_mag_to_homograph(g, device) g = convert_mag_to_homograph(g, device)
else: else:
g.ndata['feat'] = g.ndata['feat'].float() g.ndata["feat"] = g.ndata["feat"].float()
n_classes = dataset.num_classes n_classes = dataset.num_classes
labels = labels.squeeze() labels = labels.squeeze()
evaluator = get_ogb_evaluator(name) evaluator = get_ogb_evaluator(name)
print(f"# Nodes: {g.number_of_nodes()}\n" print(
f"# Edges: {g.number_of_edges()}\n" f"# Nodes: {g.number_of_nodes()}\n"
f"# Train: {len(train_nid)}\n" f"# Edges: {g.number_of_edges()}\n"
f"# Val: {len(val_nid)}\n" f"# Train: {len(train_nid)}\n"
f"# Test: {len(test_nid)}\n" f"# Val: {len(val_nid)}\n"
f"# Classes: {n_classes}") f"# Test: {len(test_nid)}\n"
f"# Classes: {n_classes}"
)
return g, labels, n_classes, train_nid, val_nid, test_nid, evaluator return g, labels, n_classes, train_nid, val_nid, test_nid, evaluator
import argparse import argparse
import time import time
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from dataset import load_dataset
import dgl import dgl
import dgl.function as fn import dgl.function as fn
from dataset import load_dataset
class FeedForwardNet(nn.Module): class FeedForwardNet(nn.Module):
...@@ -40,8 +42,16 @@ class FeedForwardNet(nn.Module): ...@@ -40,8 +42,16 @@ class FeedForwardNet(nn.Module):
class SIGN(nn.Module): class SIGN(nn.Module):
def __init__(self, in_feats, hidden, out_feats, num_hops, n_layers, def __init__(
dropout, input_drop): self,
in_feats,
hidden,
out_feats,
num_hops,
n_layers,
dropout,
input_drop,
):
super(SIGN, self).__init__() super(SIGN, self).__init__()
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
self.prelu = nn.PReLU() self.prelu = nn.PReLU()
...@@ -49,9 +59,11 @@ class SIGN(nn.Module): ...@@ -49,9 +59,11 @@ class SIGN(nn.Module):
self.input_drop = nn.Dropout(input_drop) self.input_drop = nn.Dropout(input_drop)
for hop in range(num_hops): for hop in range(num_hops):
self.inception_ffs.append( self.inception_ffs.append(
FeedForwardNet(in_feats, hidden, hidden, n_layers, dropout)) FeedForwardNet(in_feats, hidden, hidden, n_layers, dropout)
self.project = FeedForwardNet(num_hops * hidden, hidden, out_feats, )
n_layers, dropout) self.project = FeedForwardNet(
num_hops * hidden, hidden, out_feats, n_layers, dropout
)
def forward(self, feats): def forward(self, feats):
feats = [self.input_drop(feat) for feat in feats] feats = [self.input_drop(feat) for feat in feats]
...@@ -72,7 +84,7 @@ def get_n_params(model): ...@@ -72,7 +84,7 @@ def get_n_params(model):
for p in list(model.parameters()): for p in list(model.parameters()):
nn = 1 nn = 1
for s in list(p.size()): for s in list(p.size()):
nn = nn*s nn = nn * s
pp += nn pp += nn
return pp return pp
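# Equivalent one-liner in PyTorch: sum(p.numel() for p in model.parameters())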
...@@ -84,8 +96,9 @@ def neighbor_average_features(g, args): ...@@ -84,8 +96,9 @@ def neighbor_average_features(g, args):
print("Compute neighbor-averaged feats") print("Compute neighbor-averaged feats")
g.ndata["feat_0"] = g.ndata["feat"] g.ndata["feat_0"] = g.ndata["feat"]
for hop in range(1, args.R + 1): for hop in range(1, args.R + 1):
g.update_all(fn.copy_u(f"feat_{hop-1}", "msg"), g.update_all(
fn.mean("msg", f"feat_{hop}")) fn.copy_u(f"feat_{hop-1}", "msg"), fn.mean("msg", f"feat_{hop}")
)
res = [] res = []
for hop in range(args.R + 1): for hop in range(args.R + 1):
res.append(g.ndata.pop(f"feat_{hop}")) res.append(g.ndata.pop(f"feat_{hop}"))
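# The hop loop above precomputes SIGN's multi-hop inputs: each hop applies the
# neighbor-mean operator once more. A dense-tensor sketch of the same idea
# (illustrative assumption on a tiny hand-made graph, not part of this example):
#
#     A = torch.tensor([[0., 1., 0.], [1., 0., 1.], [0., 1., 0.]])
#     P = A / A.sum(1, keepdim=True)   # row-normalized adjacency = mean over neighbors
#     feats = [X]                      # X: (num_nodes, in_feats)
#     for _ in range(R):
#         feats.append(P @ feats[-1])  # matches fn.copy_u(...) + fn.mean(...) per hop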
...@@ -98,8 +111,9 @@ def neighbor_average_features(g, args): ...@@ -98,8 +111,9 @@ def neighbor_average_features(g, args):
num_target = target_mask.sum().item() num_target = target_mask.sum().item()
new_res = [] new_res = []
for x in res: for x in res:
feat = torch.zeros((num_target,) + x.shape[1:], feat = torch.zeros(
dtype=x.dtype, device=x.device) (num_target,) + x.shape[1:], dtype=x.dtype, device=x.device
)
feat[target_ids] = x[target_mask] feat[target_ids] = x[target_mask]
new_res.append(feat) new_res.append(feat)
res = new_res res = new_res
...@@ -112,15 +126,23 @@ def prepare_data(device, args): ...@@ -112,15 +126,23 @@ def prepare_data(device, args):
""" """
data = load_dataset(args.dataset, device) data = load_dataset(args.dataset, device)
g, labels, n_classes, train_nid, val_nid, test_nid, evaluator = data g, labels, n_classes, train_nid, val_nid, test_nid, evaluator = data
in_feats = g.ndata['feat'].shape[1] in_feats = g.ndata["feat"].shape[1]
feats = neighbor_average_features(g, args) feats = neighbor_average_features(g, args)
labels = labels.to(device) labels = labels.to(device)
# move to device # move to device
train_nid = train_nid.to(device) train_nid = train_nid.to(device)
val_nid = val_nid.to(device) val_nid = val_nid.to(device)
test_nid = test_nid.to(device) test_nid = test_nid.to(device)
return feats, labels, in_feats, n_classes, \ return (
train_nid, val_nid, test_nid, evaluator feats,
labels,
in_feats,
n_classes,
train_nid,
val_nid,
test_nid,
evaluator,
)
def train(model, feats, labels, loss_fcn, optimizer, train_loader): def train(model, feats, labels, loss_fcn, optimizer, train_loader):
...@@ -134,8 +156,9 @@ def train(model, feats, labels, loss_fcn, optimizer, train_loader): ...@@ -134,8 +156,9 @@ def train(model, feats, labels, loss_fcn, optimizer, train_loader):
optimizer.step() optimizer.step()
def test(model, feats, labels, test_loader, evaluator, def test(
train_nid, val_nid, test_nid): model, feats, labels, test_loader, evaluator, train_nid, val_nid, test_nid
):
model.eval() model.eval()
device = labels.device device = labels.device
preds = [] preds = []
...@@ -151,24 +174,44 @@ def test(model, feats, labels, test_loader, evaluator, ...@@ -151,24 +174,44 @@ def test(model, feats, labels, test_loader, evaluator,
def run(args, data, device): def run(args, data, device):
feats, labels, in_size, num_classes, \ (
train_nid, val_nid, test_nid, evaluator = data feats,
labels,
in_size,
num_classes,
train_nid,
val_nid,
test_nid,
evaluator,
) = data
train_loader = torch.utils.data.DataLoader( train_loader = torch.utils.data.DataLoader(
train_nid, batch_size=args.batch_size, shuffle=True, drop_last=False) train_nid, batch_size=args.batch_size, shuffle=True, drop_last=False
)
test_loader = torch.utils.data.DataLoader( test_loader = torch.utils.data.DataLoader(
torch.arange(labels.shape[0]), batch_size=args.eval_batch_size, torch.arange(labels.shape[0]),
shuffle=False, drop_last=False) batch_size=args.eval_batch_size,
shuffle=False,
drop_last=False,
)
# Initialize model and optimizer for each run # Initialize model and optimizer for each run
num_hops = args.R + 1 num_hops = args.R + 1
model = SIGN(in_size, args.num_hidden, num_classes, num_hops, model = SIGN(
args.ff_layer, args.dropout, args.input_dropout) in_size,
args.num_hidden,
num_classes,
num_hops,
args.ff_layer,
args.dropout,
args.input_dropout,
)
model = model.to(device) model = model.to(device)
print("# Params:", get_n_params(model)) print("# Params:", get_n_params(model))
loss_fcn = nn.CrossEntropyLoss() loss_fcn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, optimizer = torch.optim.Adam(
weight_decay=args.weight_decay) model.parameters(), lr=args.lr, weight_decay=args.weight_decay
)
# Start training # Start training
best_epoch = 0 best_epoch = 0
...@@ -180,8 +223,16 @@ def run(args, data, device): ...@@ -180,8 +223,16 @@ def run(args, data, device):
if epoch % args.eval_every == 0: if epoch % args.eval_every == 0:
with torch.no_grad(): with torch.no_grad():
acc = test(model, feats, labels, test_loader, evaluator, acc = test(
train_nid, val_nid, test_nid) model,
feats,
labels,
test_loader,
evaluator,
train_nid,
val_nid,
test_nid,
)
end = time.time() end = time.time()
log = "Epoch {}, Time(s): {:.4f}, ".format(epoch, end - start) log = "Epoch {}, Time(s): {:.4f}, ".format(epoch, end - start)
log += "Acc: Train {:.4f}, Val {:.4f}, Test {:.4f}".format(*acc) log += "Acc: Train {:.4f}, Val {:.4f}, Test {:.4f}".format(*acc)
...@@ -191,8 +242,11 @@ def run(args, data, device): ...@@ -191,8 +242,11 @@ def run(args, data, device):
best_val = acc[1] best_val = acc[1]
best_test = acc[2] best_test = acc[2]
print("Best Epoch {}, Val {:.4f}, Test {:.4f}".format( print(
best_epoch, best_val, best_test)) "Best Epoch {}, Val {:.4f}, Test {:.4f}".format(
best_epoch, best_val, best_test
)
)
return best_val, best_test return best_val, best_test
...@@ -212,34 +266,51 @@ def main(args): ...@@ -212,34 +266,51 @@ def main(args):
val_accs.append(best_val) val_accs.append(best_val)
test_accs.append(best_test) test_accs.append(best_test)
print(f"Average val accuracy: {np.mean(val_accs):.4f}, " print(
f"std: {np.std(val_accs):.4f}") f"Average val accuracy: {np.mean(val_accs):.4f}, "
print(f"Average test accuracy: {np.mean(test_accs):.4f}, " f"std: {np.std(val_accs):.4f}"
f"std: {np.std(test_accs):.4f}") )
print(
f"Average test accuracy: {np.mean(test_accs):.4f}, "
f"std: {np.std(test_accs):.4f}"
)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="SIGN") parser = argparse.ArgumentParser(description="SIGN")
parser.add_argument("--num-epochs", type=int, default=1000) parser.add_argument("--num-epochs", type=int, default=1000)
parser.add_argument("--num-hidden", type=int, default=512) parser.add_argument("--num-hidden", type=int, default=512)
parser.add_argument("--R", type=int, default=5, parser.add_argument("--R", type=int, default=5, help="number of hops")
help="number of hops")
parser.add_argument("--lr", type=float, default=0.001) parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--dataset", type=str, default="ogbn-mag") parser.add_argument("--dataset", type=str, default="ogbn-mag")
parser.add_argument("--dropout", type=float, default=0.5, parser.add_argument(
help="dropout on activation") "--dropout", type=float, default=0.5, help="dropout on activation"
)
parser.add_argument("--gpu", type=int, default=0) parser.add_argument("--gpu", type=int, default=0)
parser.add_argument("--weight-decay", type=float, default=0) parser.add_argument("--weight-decay", type=float, default=0)
parser.add_argument("--eval-every", type=int, default=10) parser.add_argument("--eval-every", type=int, default=10)
parser.add_argument("--batch-size", type=int, default=50000) parser.add_argument("--batch-size", type=int, default=50000)
parser.add_argument("--eval-batch-size", type=int, default=100000, parser.add_argument(
help="evaluation batch size") "--eval-batch-size",
parser.add_argument("--ff-layer", type=int, default=2, type=int,
help="number of feed-forward layers") default=100000,
parser.add_argument("--input-dropout", type=float, default=0, help="evaluation batch size",
help="dropout on input features") )
parser.add_argument("--num-runs", type=int, default=10, parser.add_argument(
help="number of times to repeat the experiment") "--ff-layer", type=int, default=2, help="number of feed-forward layers"
)
parser.add_argument(
"--input-dropout",
type=float,
default=0,
help="dropout on input features",
)
parser.add_argument(
"--num-runs",
type=int,
default=10,
help="number of times to repeat the experiment",
)
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
......
import ogb import argparse
from ogb.lsc import MAG240MDataset import os
import tqdm
import numpy as np import numpy as np
import ogb
import torch import torch
import tqdm
from ogb.lsc import MAG240MDataset
import dgl import dgl
import dgl.function as fn import dgl.function as fn
import argparse
import os
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--rootdir', type=str, default='.', help='Directory to download the OGB dataset.') parser.add_argument(
parser.add_argument('--author-output-path', type=str, help='Path to store the author features.') "--rootdir",
parser.add_argument('--inst-output-path', type=str, type=str,
help='Path to store the institution features.') default=".",
parser.add_argument('--graph-output-path', type=str, help='Path to store the graph.') help="Directory to download the OGB dataset.",
parser.add_argument('--graph-format', type=str, default='csc', help='Graph format (coo, csr or csc).') )
parser.add_argument('--graph-as-homogeneous', action='store_true', help='Store the graph as DGL homogeneous graph.') parser.add_argument(
parser.add_argument('--full-output-path', type=str, "--author-output-path", type=str, help="Path to store the author features."
help='Path to store features of all nodes. Effective only when graph is homogeneous.') )
parser.add_argument(
"--inst-output-path",
type=str,
help="Path to store the institution features.",
)
parser.add_argument(
"--graph-output-path", type=str, help="Path to store the graph."
)
parser.add_argument(
"--graph-format",
type=str,
default="csc",
help="Graph format (coo, csr or csc).",
)
parser.add_argument(
"--graph-as-homogeneous",
action="store_true",
help="Store the graph as DGL homogeneous graph.",
)
parser.add_argument(
"--full-output-path",
type=str,
help="Path to store features of all nodes. Effective only when graph is homogeneous.",
)
args = parser.parse_args() args = parser.parse_args()
print('Building graph') print("Building graph")
dataset = MAG240MDataset(root=args.rootdir) dataset = MAG240MDataset(root=args.rootdir)
ei_writes = dataset.edge_index('author', 'writes', 'paper') ei_writes = dataset.edge_index("author", "writes", "paper")
ei_cites = dataset.edge_index('paper', 'paper') ei_cites = dataset.edge_index("paper", "paper")
ei_affiliated = dataset.edge_index('author', 'institution') ei_affiliated = dataset.edge_index("author", "institution")
# We order the nodes as authors first, then institutions, then papers (see the offsets below). # We order the nodes as authors first, then institutions, then papers (see the offsets below).
author_offset = 0 author_offset = 0
inst_offset = author_offset + dataset.num_authors inst_offset = author_offset + dataset.num_authors
paper_offset = inst_offset + dataset.num_institutions paper_offset = inst_offset + dataset.num_institutions
g = dgl.heterograph({ g = dgl.heterograph(
('author', 'write', 'paper'): (ei_writes[0], ei_writes[1]), {
('paper', 'write-by', 'author'): (ei_writes[1], ei_writes[0]), ("author", "write", "paper"): (ei_writes[0], ei_writes[1]),
('author', 'affiliate-with', 'institution'): (ei_affiliated[0], ei_affiliated[1]), ("paper", "write-by", "author"): (ei_writes[1], ei_writes[0]),
('institution', 'affiliate', 'author'): (ei_affiliated[1], ei_affiliated[0]), ("author", "affiliate-with", "institution"): (
('paper', 'cite', 'paper'): (np.concatenate([ei_cites[0], ei_cites[1]]), np.concatenate([ei_cites[1], ei_cites[0]])) ei_affiliated[0],
}) ei_affiliated[1],
),
("institution", "affiliate", "author"): (
ei_affiliated[1],
ei_affiliated[0],
),
("paper", "cite", "paper"): (
np.concatenate([ei_cites[0], ei_cites[1]]),
np.concatenate([ei_cites[1], ei_cites[0]]),
),
}
)
paper_feat = dataset.paper_feat paper_feat = dataset.paper_feat
author_feat = np.memmap(args.author_output_path, mode='w+', dtype='float16', shape=(dataset.num_authors, dataset.num_paper_features)) author_feat = np.memmap(
inst_feat = np.memmap(args.inst_output_path, mode='w+', dtype='float16', shape=(dataset.num_institutions, dataset.num_paper_features)) args.author_output_path,
mode="w+",
dtype="float16",
shape=(dataset.num_authors, dataset.num_paper_features),
)
inst_feat = np.memmap(
args.inst_output_path,
mode="w+",
dtype="float16",
shape=(dataset.num_institutions, dataset.num_paper_features),
)
# Iteratively process author and institution features along the feature dimension. # Iteratively process author and institution features along the feature dimension.
BLOCK_COLS = 16 BLOCK_COLS = 16
with tqdm.trange(0, dataset.num_paper_features, BLOCK_COLS) as tq: with tqdm.trange(0, dataset.num_paper_features, BLOCK_COLS) as tq:
for start in tq: for start in tq:
tq.set_postfix_str('Reading paper features...') tq.set_postfix_str("Reading paper features...")
g.nodes['paper'].data['x'] = torch.FloatTensor(paper_feat[:, start:start + BLOCK_COLS].astype('float32')) g.nodes["paper"].data["x"] = torch.FloatTensor(
paper_feat[:, start : start + BLOCK_COLS].astype("float32")
)
# Compute author features... # Compute author features...
tq.set_postfix_str('Computing author features...') tq.set_postfix_str("Computing author features...")
g.update_all(fn.copy_u('x', 'm'), fn.mean('m', 'x'), etype='write-by') g.update_all(fn.copy_u("x", "m"), fn.mean("m", "x"), etype="write-by")
# Then institution features... # Then institution features...
tq.set_postfix_str('Computing institution features...') tq.set_postfix_str("Computing institution features...")
g.update_all(fn.copy_u('x', 'm'), fn.mean('m', 'x'), etype='affiliate-with') g.update_all(
tq.set_postfix_str('Writing author features...') fn.copy_u("x", "m"), fn.mean("m", "x"), etype="affiliate-with"
author_feat[:, start:start + BLOCK_COLS] = g.nodes['author'].data['x'].numpy().astype('float16') )
tq.set_postfix_str('Writing institution features...') tq.set_postfix_str("Writing author features...")
inst_feat[:, start:start + BLOCK_COLS] = g.nodes['institution'].data['x'].numpy().astype('float16') author_feat[:, start : start + BLOCK_COLS] = (
del g.nodes['paper'].data['x'] g.nodes["author"].data["x"].numpy().astype("float16")
del g.nodes['author'].data['x'] )
del g.nodes['institution'].data['x'] tq.set_postfix_str("Writing institution features...")
inst_feat[:, start : start + BLOCK_COLS] = (
g.nodes["institution"].data["x"].numpy().astype("float16")
)
del g.nodes["paper"].data["x"]
del g.nodes["author"].data["x"]
del g.nodes["institution"].data["x"]
author_feat.flush() author_feat.flush()
inst_feat.flush() inst_feat.flush()
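# Back-of-envelope motivation for the column blocking above (rough assumption:
# on the order of 1e8 node rows and 768 paper-feature columns): a full float32
# copy would need about 1e8 * 768 * 4 B ≈ 300 GB, while BLOCK_COLS = 16 keeps the
# in-flight node feature tensor near 1e8 * 16 * 4 B ≈ 6.4 GB.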
...@@ -73,34 +128,56 @@ if args.graph_as_homogeneous: ...@@ -73,34 +128,56 @@ if args.graph_as_homogeneous:
# DGL also ensures that the node types are sorted in ascending order. # DGL also ensures that the node types are sorted in ascending order.
assert torch.equal( assert torch.equal(
g.ndata[dgl.NTYPE], g.ndata[dgl.NTYPE],
torch.cat([torch.full((dataset.num_authors,), 0), torch.cat(
torch.full((dataset.num_institutions,), 1), [
torch.full((dataset.num_papers,), 2)])) torch.full((dataset.num_authors,), 0),
torch.full((dataset.num_institutions,), 1),
torch.full((dataset.num_papers,), 2),
]
),
)
assert torch.equal( assert torch.equal(
g.ndata[dgl.NID], g.ndata[dgl.NID],
torch.cat([torch.arange(dataset.num_authors), torch.cat(
torch.arange(dataset.num_institutions), [
torch.arange(dataset.num_papers)])) torch.arange(dataset.num_authors),
g.edata['etype'] = g.edata[dgl.ETYPE].byte() torch.arange(dataset.num_institutions),
torch.arange(dataset.num_papers),
]
),
)
g.edata["etype"] = g.edata[dgl.ETYPE].byte()
del g.edata[dgl.ETYPE] del g.edata[dgl.ETYPE]
del g.ndata[dgl.NTYPE] del g.ndata[dgl.NTYPE]
del g.ndata[dgl.NID] del g.ndata[dgl.NID]
# Assemble the full node feature matrix # Assemble the full node feature matrix
full_feat = np.memmap( full_feat = np.memmap(
args.full_output_path, mode='w+', dtype='float16', args.full_output_path,
shape=(dataset.num_authors + dataset.num_institutions + dataset.num_papers, dataset.num_paper_features)) mode="w+",
dtype="float16",
shape=(
dataset.num_authors + dataset.num_institutions + dataset.num_papers,
dataset.num_paper_features,
),
)
BLOCK_ROWS = 100000 BLOCK_ROWS = 100000
for start in tqdm.trange(0, dataset.num_authors, BLOCK_ROWS): for start in tqdm.trange(0, dataset.num_authors, BLOCK_ROWS):
end = min(dataset.num_authors, start + BLOCK_ROWS) end = min(dataset.num_authors, start + BLOCK_ROWS)
full_feat[author_offset + start:author_offset + end] = author_feat[start:end] full_feat[author_offset + start : author_offset + end] = author_feat[
start:end
]
for start in tqdm.trange(0, dataset.num_institutions, BLOCK_ROWS): for start in tqdm.trange(0, dataset.num_institutions, BLOCK_ROWS):
end = min(dataset.num_institutions, start + BLOCK_ROWS) end = min(dataset.num_institutions, start + BLOCK_ROWS)
full_feat[inst_offset + start:inst_offset + end] = inst_feat[start:end] full_feat[inst_offset + start : inst_offset + end] = inst_feat[
start:end
]
for start in tqdm.trange(0, dataset.num_papers, BLOCK_ROWS): for start in tqdm.trange(0, dataset.num_papers, BLOCK_ROWS):
end = min(dataset.num_papers, start + BLOCK_ROWS) end = min(dataset.num_papers, start + BLOCK_ROWS)
full_feat[paper_offset + start:paper_offset + end] = paper_feat[start:end] full_feat[paper_offset + start : paper_offset + end] = paper_feat[
start:end
]
# Convert the graph to the given format and save. (The RGAT baseline needs CSC graph) # Convert the graph to the given format and save. (The RGAT baseline needs CSC graph)
g = g.formats(args.graph_format) g = g.formats(args.graph_format)
dgl.save_graphs(args.graph_output_path, g) dgl.save_graphs(args.graph_output_path, g)
#!/usr/bin/env python #!/usr/bin/env python
# coding: utf-8 # coding: utf-8
import argparse
import time
import numpy as np
import ogb import ogb
from ogb.lsc import MAG240MDataset, MAG240MEvaluator
import dgl
import torch import torch
import numpy as np import torch.nn as nn
import time import torch.nn.functional as F
import tqdm import tqdm
from ogb.lsc import MAG240MDataset, MAG240MEvaluator
import dgl
import dgl.function as fn import dgl.function as fn
import numpy as np
import dgl.nn as dglnn import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F
import argparse
class RGAT(nn.Module): class RGAT(nn.Module):
def __init__(self, in_channels, out_channels, hidden_channels, num_etypes, num_layers, num_heads, dropout, pred_ntype): def __init__(
self,
in_channels,
out_channels,
hidden_channels,
num_etypes,
num_layers,
num_heads,
dropout,
pred_ntype,
):
super().__init__() super().__init__()
self.convs = nn.ModuleList() self.convs = nn.ModuleList()
self.norms = nn.ModuleList() self.norms = nn.ModuleList()
self.skips = nn.ModuleList() self.skips = nn.ModuleList()
self.convs.append(nn.ModuleList([ self.convs.append(
dglnn.GATConv(in_channels, hidden_channels // num_heads, num_heads, allow_zero_in_degree=True) nn.ModuleList(
for _ in range(num_etypes) [
])) dglnn.GATConv(
in_channels,
hidden_channels // num_heads,
num_heads,
allow_zero_in_degree=True,
)
for _ in range(num_etypes)
]
)
)
self.norms.append(nn.BatchNorm1d(hidden_channels)) self.norms.append(nn.BatchNorm1d(hidden_channels))
self.skips.append(nn.Linear(in_channels, hidden_channels)) self.skips.append(nn.Linear(in_channels, hidden_channels))
for _ in range(num_layers - 1): for _ in range(num_layers - 1):
self.convs.append(nn.ModuleList([ self.convs.append(
dglnn.GATConv(hidden_channels, hidden_channels // num_heads, num_heads, allow_zero_in_degree=True) nn.ModuleList(
for _ in range(num_etypes) [
])) dglnn.GATConv(
hidden_channels,
hidden_channels // num_heads,
num_heads,
allow_zero_in_degree=True,
)
for _ in range(num_etypes)
]
)
)
self.norms.append(nn.BatchNorm1d(hidden_channels)) self.norms.append(nn.BatchNorm1d(hidden_channels))
self.skips.append(nn.Linear(hidden_channels, hidden_channels)) self.skips.append(nn.Linear(hidden_channels, hidden_channels))
self.mlp = nn.Sequential( self.mlp = nn.Sequential(
nn.Linear(hidden_channels, hidden_channels), nn.Linear(hidden_channels, hidden_channels),
nn.BatchNorm1d(hidden_channels), nn.BatchNorm1d(hidden_channels),
nn.ReLU(), nn.ReLU(),
nn.Dropout(dropout), nn.Dropout(dropout),
nn.Linear(hidden_channels, out_channels) nn.Linear(hidden_channels, out_channels),
) )
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
self.hidden_channels = hidden_channels self.hidden_channels = hidden_channels
self.pred_ntype = pred_ntype self.pred_ntype = pred_ntype
self.num_etypes = num_etypes self.num_etypes = num_etypes
def forward(self, mfgs, x): def forward(self, mfgs, x):
for i in range(len(mfgs)): for i in range(len(mfgs)):
mfg = mfgs[i] mfg = mfgs[i]
x_dst = x[:mfg.num_dst_nodes()] x_dst = x[: mfg.num_dst_nodes()]
n_src = mfg.num_src_nodes() n_src = mfg.num_src_nodes()
n_dst = mfg.num_dst_nodes() n_dst = mfg.num_dst_nodes()
mfg = dgl.block_to_graph(mfg) mfg = dgl.block_to_graph(mfg)
x_skip = self.skips[i](x_dst) x_skip = self.skips[i](x_dst)
for j in range(self.num_etypes): for j in range(self.num_etypes):
subg = mfg.edge_subgraph(mfg.edata['etype'] == j, relabel_nodes=False) subg = mfg.edge_subgraph(
x_skip += self.convs[i][j](subg, (x, x_dst)).view(-1, self.hidden_channels) mfg.edata["etype"] == j, relabel_nodes=False
)
x_skip += self.convs[i][j](subg, (x, x_dst)).view(
-1, self.hidden_channels
)
x = self.norms[i](x_skip) x = self.norms[i](x_skip)
x = F.elu(x) x = F.elu(x)
x = self.dropout(x) x = self.dropout(x)
...@@ -76,27 +110,34 @@ class ExternalNodeCollator(dgl.dataloading.NodeCollator): ...@@ -76,27 +110,34 @@ class ExternalNodeCollator(dgl.dataloading.NodeCollator):
def collate(self, items): def collate(self, items):
input_nodes, output_nodes, mfgs = super().collate(items) input_nodes, output_nodes, mfgs = super().collate(items)
# Copy input features # Copy input features
mfgs[0].srcdata['x'] = torch.FloatTensor(self.feats[input_nodes]) mfgs[0].srcdata["x"] = torch.FloatTensor(self.feats[input_nodes])
mfgs[-1].dstdata['y'] = torch.LongTensor(self.label[output_nodes - self.offset]) mfgs[-1].dstdata["y"] = torch.LongTensor(
self.label[output_nodes - self.offset]
)
return input_nodes, output_nodes, mfgs return input_nodes, output_nodes, mfgs
def train(args, dataset, g, feats, paper_offset): def train(args, dataset, g, feats, paper_offset):
print('Loading masks and labels') print("Loading masks and labels")
train_idx = torch.LongTensor(dataset.get_idx_split('train')) + paper_offset train_idx = torch.LongTensor(dataset.get_idx_split("train")) + paper_offset
valid_idx = torch.LongTensor(dataset.get_idx_split('valid')) + paper_offset valid_idx = torch.LongTensor(dataset.get_idx_split("valid")) + paper_offset
label = dataset.paper_label label = dataset.paper_label
print('Initializing dataloader...') print("Initializing dataloader...")
sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 25]) sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 25])
train_collator = ExternalNodeCollator(g, train_idx, sampler, paper_offset, feats, label) train_collator = ExternalNodeCollator(
valid_collator = ExternalNodeCollator(g, valid_idx, sampler, paper_offset, feats, label) g, train_idx, sampler, paper_offset, feats, label
)
valid_collator = ExternalNodeCollator(
g, valid_idx, sampler, paper_offset, feats, label
)
train_dataloader = torch.utils.data.DataLoader( train_dataloader = torch.utils.data.DataLoader(
train_collator.dataset, train_collator.dataset,
batch_size=1024, batch_size=1024,
shuffle=True, shuffle=True,
drop_last=False, drop_last=False,
collate_fn=train_collator.collate, collate_fn=train_collator.collate,
num_workers=4 num_workers=4,
) )
valid_dataloader = torch.utils.data.DataLoader( valid_dataloader = torch.utils.data.DataLoader(
valid_collator.dataset, valid_collator.dataset,
...@@ -104,11 +145,20 @@ def train(args, dataset, g, feats, paper_offset): ...@@ -104,11 +145,20 @@ def train(args, dataset, g, feats, paper_offset):
shuffle=True, shuffle=True,
drop_last=False, drop_last=False,
collate_fn=valid_collator.collate, collate_fn=valid_collator.collate,
num_workers=2 num_workers=2,
) )
print('Initializing model...') print("Initializing model...")
model = RGAT(dataset.num_paper_features, dataset.num_classes, 1024, 5, 2, 4, 0.5, 'paper').cuda() model = RGAT(
dataset.num_paper_features,
dataset.num_classes,
1024,
5,
2,
4,
0.5,
"paper",
).cuda()
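# Positional arguments above, matched to the RGAT constructor defined earlier:
# hidden_channels=1024, num_etypes=5 (write / write-by / affiliate-with / affiliate / cite),
# num_layers=2, num_heads=4, dropout=0.5, pred_ntype="paper".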
opt = torch.optim.Adam(model.parameters(), lr=0.001) opt = torch.optim.Adam(model.parameters(), lr=0.001)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=25, gamma=0.25) sched = torch.optim.lr_scheduler.StepLR(opt, step_size=25, gamma=0.25)
...@@ -118,114 +168,170 @@ def train(args, dataset, g, feats, paper_offset): ...@@ -118,114 +168,170 @@ def train(args, dataset, g, feats, paper_offset):
model.train() model.train()
with tqdm.tqdm(train_dataloader) as tq: with tqdm.tqdm(train_dataloader) as tq:
for i, (input_nodes, output_nodes, mfgs) in enumerate(tq): for i, (input_nodes, output_nodes, mfgs) in enumerate(tq):
mfgs = [g.to('cuda') for g in mfgs] mfgs = [g.to("cuda") for g in mfgs]
x = mfgs[0].srcdata['x'] x = mfgs[0].srcdata["x"]
y = mfgs[-1].dstdata['y'] y = mfgs[-1].dstdata["y"]
y_hat = model(mfgs, x) y_hat = model(mfgs, x)
loss = F.cross_entropy(y_hat, y) loss = F.cross_entropy(y_hat, y)
opt.zero_grad() opt.zero_grad()
loss.backward() loss.backward()
opt.step() opt.step()
acc = (y_hat.argmax(1) == y).float().mean() acc = (y_hat.argmax(1) == y).float().mean()
tq.set_postfix({'loss': '%.4f' % loss.item(), 'acc': '%.4f' % acc.item()}, refresh=False) tq.set_postfix(
{"loss": "%.4f" % loss.item(), "acc": "%.4f" % acc.item()},
refresh=False,
)
model.eval() model.eval()
correct = total = 0 correct = total = 0
for i, (input_nodes, output_nodes, mfgs) in enumerate(tqdm.tqdm(valid_dataloader)): for i, (input_nodes, output_nodes, mfgs) in enumerate(
tqdm.tqdm(valid_dataloader)
):
with torch.no_grad(): with torch.no_grad():
mfgs = [g.to('cuda') for g in mfgs] mfgs = [g.to("cuda") for g in mfgs]
x = mfgs[0].srcdata['x'] x = mfgs[0].srcdata["x"]
y = mfgs[-1].dstdata['y'] y = mfgs[-1].dstdata["y"]
y_hat = model(mfgs, x) y_hat = model(mfgs, x)
correct += (y_hat.argmax(1) == y).sum().item() correct += (y_hat.argmax(1) == y).sum().item()
total += y_hat.shape[0] total += y_hat.shape[0]
acc = correct / total acc = correct / total
print('Validation accuracy:', acc) print("Validation accuracy:", acc)
sched.step() sched.step()
if best_acc < acc: if best_acc < acc:
best_acc = acc best_acc = acc
print('Updating best model...') print("Updating best model...")
torch.save(model.state_dict(), args.model_path) torch.save(model.state_dict(), args.model_path)
def test(args, dataset, g, feats, paper_offset): def test(args, dataset, g, feats, paper_offset):
print('Loading masks and labels...') print("Loading masks and labels...")
valid_idx = torch.LongTensor(dataset.get_idx_split('valid')) + paper_offset valid_idx = torch.LongTensor(dataset.get_idx_split("valid")) + paper_offset
test_idx = torch.LongTensor(dataset.get_idx_split('test')) + paper_offset test_idx = torch.LongTensor(dataset.get_idx_split("test")) + paper_offset
label = dataset.paper_label label = dataset.paper_label
print('Initializing data loader...') print("Initializing data loader...")
sampler = dgl.dataloading.MultiLayerNeighborSampler([160, 160]) sampler = dgl.dataloading.MultiLayerNeighborSampler([160, 160])
valid_collator = ExternalNodeCollator(g, valid_idx, sampler, paper_offset, feats, label) valid_collator = ExternalNodeCollator(
g, valid_idx, sampler, paper_offset, feats, label
)
valid_dataloader = torch.utils.data.DataLoader( valid_dataloader = torch.utils.data.DataLoader(
valid_collator.dataset, valid_collator.dataset,
batch_size=16, batch_size=16,
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
collate_fn=valid_collator.collate, collate_fn=valid_collator.collate,
num_workers=2 num_workers=2,
)
test_collator = ExternalNodeCollator(
g, test_idx, sampler, paper_offset, feats, label
) )
test_collator = ExternalNodeCollator(g, test_idx, sampler, paper_offset, feats, label)
test_dataloader = torch.utils.data.DataLoader( test_dataloader = torch.utils.data.DataLoader(
test_collator.dataset, test_collator.dataset,
batch_size=16, batch_size=16,
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
collate_fn=test_collator.collate, collate_fn=test_collator.collate,
num_workers=4 num_workers=4,
) )
print('Loading model...') print("Loading model...")
model = RGAT(dataset.num_paper_features, dataset.num_classes, 1024, 5, 2, 4, 0.5, 'paper').cuda() model = RGAT(
dataset.num_paper_features,
dataset.num_classes,
1024,
5,
2,
4,
0.5,
"paper",
).cuda()
model.load_state_dict(torch.load(args.model_path)) model.load_state_dict(torch.load(args.model_path))
model.eval() model.eval()
correct = total = 0 correct = total = 0
for i, (input_nodes, output_nodes, mfgs) in enumerate(tqdm.tqdm(valid_dataloader)): for i, (input_nodes, output_nodes, mfgs) in enumerate(
tqdm.tqdm(valid_dataloader)
):
with torch.no_grad(): with torch.no_grad():
mfgs = [g.to('cuda') for g in mfgs] mfgs = [g.to("cuda") for g in mfgs]
x = mfgs[0].srcdata['x'] x = mfgs[0].srcdata["x"]
y = mfgs[-1].dstdata['y'] y = mfgs[-1].dstdata["y"]
y_hat = model(mfgs, x) y_hat = model(mfgs, x)
correct += (y_hat.argmax(1) == y).sum().item() correct += (y_hat.argmax(1) == y).sum().item()
total += y_hat.shape[0] total += y_hat.shape[0]
acc = correct / total acc = correct / total
print('Validation accuracy:', acc) print("Validation accuracy:", acc)
evaluator = MAG240MEvaluator() evaluator = MAG240MEvaluator()
y_preds = [] y_preds = []
for i, (input_nodes, output_nodes, mfgs) in enumerate(tqdm.tqdm(test_dataloader)): for i, (input_nodes, output_nodes, mfgs) in enumerate(
tqdm.tqdm(test_dataloader)
):
with torch.no_grad(): with torch.no_grad():
mfgs = [g.to('cuda') for g in mfgs] mfgs = [g.to("cuda") for g in mfgs]
x = mfgs[0].srcdata['x'] x = mfgs[0].srcdata["x"]
y = mfgs[-1].dstdata['y'] y = mfgs[-1].dstdata["y"]
y_hat = model(mfgs, x) y_hat = model(mfgs, x)
y_preds.append(y_hat.argmax(1).cpu()) y_preds.append(y_hat.argmax(1).cpu())
evaluator.save_test_submission({'y_pred': torch.cat(y_preds)}, args.submission_path) evaluator.save_test_submission(
{"y_pred": torch.cat(y_preds)}, args.submission_path
)
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--rootdir', type=str, default='.', help='Directory to download the OGB dataset.') parser.add_argument(
parser.add_argument('--graph-path', type=str, default='./graph.dgl', help='Path to the graph.') "--rootdir",
parser.add_argument('--full-feature-path', type=str, default='./full.npy', type=str,
help='Path to the features of all nodes.') default=".",
parser.add_argument('--epochs', type=int, default=100, help='Number of epochs.') help="Directory to download the OGB dataset.",
parser.add_argument('--model-path', type=str, default='./model.pt', help='Path to store the best model.') )
parser.add_argument('--submission-path', type=str, default='./results', help='Submission directory.') parser.add_argument(
"--graph-path",
type=str,
default="./graph.dgl",
help="Path to the graph.",
)
parser.add_argument(
"--full-feature-path",
type=str,
default="./full.npy",
help="Path to the features of all nodes.",
)
parser.add_argument(
"--epochs", type=int, default=100, help="Number of epochs."
)
parser.add_argument(
"--model-path",
type=str,
default="./model.pt",
help="Path to store the best model.",
)
parser.add_argument(
"--submission-path",
type=str,
default="./results",
help="Submission directory.",
)
args = parser.parse_args() args = parser.parse_args()
dataset = MAG240MDataset(root=args.rootdir) dataset = MAG240MDataset(root=args.rootdir)
print('Loading graph') print("Loading graph")
(g,), _ = dgl.load_graphs(args.graph_path) (g,), _ = dgl.load_graphs(args.graph_path)
g = g.formats(['csc']) g = g.formats(["csc"])
print('Loading features') print("Loading features")
paper_offset = dataset.num_authors + dataset.num_institutions paper_offset = dataset.num_authors + dataset.num_institutions
num_nodes = paper_offset + dataset.num_papers num_nodes = paper_offset + dataset.num_papers
num_features = dataset.num_paper_features num_features = dataset.num_paper_features
feats = np.memmap(args.full_feature_path, mode='r', dtype='float16', shape=(num_nodes, num_features)) feats = np.memmap(
args.full_feature_path,
mode="r",
dtype="float16",
shape=(num_nodes, num_features),
)
if args.epochs != 0: if args.epochs != 0:
train(args, dataset, g, feats, paper_offset) train(args, dataset, g, feats, paper_offset)
......
#!/usr/bin/env python #!/usr/bin/env python
# coding: utf-8 # coding: utf-8
import argparse
import math import math
import sys
from collections import OrderedDict
from ogb.lsc import MAG240MDataset, MAG240MEvaluator
import dgl
import torch
import tqdm
import numpy as np import numpy as np
import dgl.nn as dglnn import torch
import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import argparse import tqdm
import torch.multiprocessing as mp from ogb.lsc import MAG240MDataset, MAG240MEvaluator
import sys
from torch.nn.parallel import DistributedDataParallel from torch.nn.parallel import DistributedDataParallel
from collections import OrderedDict
import dgl
import dgl.nn as dglnn
class RGAT(nn.Module): class RGAT(nn.Module):
def __init__(self, in_channels, out_channels, hidden_channels, num_etypes, num_layers, num_heads, dropout, def __init__(
pred_ntype): self,
in_channels,
out_channels,
hidden_channels,
num_etypes,
num_layers,
num_heads,
dropout,
pred_ntype,
):
super().__init__() super().__init__()
self.convs = nn.ModuleList() self.convs = nn.ModuleList()
self.norms = nn.ModuleList() self.norms = nn.ModuleList()
self.skips = nn.ModuleList() self.skips = nn.ModuleList()
self.convs.append(nn.ModuleList([ self.convs.append(
dglnn.GATConv(in_channels, hidden_channels // num_heads, num_heads, allow_zero_in_degree=True) nn.ModuleList(
for _ in range(num_etypes) [
])) dglnn.GATConv(
in_channels,
hidden_channels // num_heads,
num_heads,
allow_zero_in_degree=True,
)
for _ in range(num_etypes)
]
)
)
self.norms.append(nn.BatchNorm1d(hidden_channels)) self.norms.append(nn.BatchNorm1d(hidden_channels))
self.skips.append(nn.Linear(in_channels, hidden_channels)) self.skips.append(nn.Linear(in_channels, hidden_channels))
for _ in range(num_layers - 1): for _ in range(num_layers - 1):
self.convs.append(nn.ModuleList([ self.convs.append(
dglnn.GATConv(hidden_channels, hidden_channels // num_heads, num_heads, allow_zero_in_degree=True) nn.ModuleList(
for _ in range(num_etypes) [
])) dglnn.GATConv(
hidden_channels,
hidden_channels // num_heads,
num_heads,
allow_zero_in_degree=True,
)
for _ in range(num_etypes)
]
)
)
self.norms.append(nn.BatchNorm1d(hidden_channels)) self.norms.append(nn.BatchNorm1d(hidden_channels))
self.skips.append(nn.Linear(hidden_channels, hidden_channels)) self.skips.append(nn.Linear(hidden_channels, hidden_channels))
...@@ -44,7 +72,7 @@ class RGAT(nn.Module): ...@@ -44,7 +72,7 @@ class RGAT(nn.Module):
nn.BatchNorm1d(hidden_channels), nn.BatchNorm1d(hidden_channels),
nn.ReLU(), nn.ReLU(),
nn.Dropout(dropout), nn.Dropout(dropout),
nn.Linear(hidden_channels, out_channels) nn.Linear(hidden_channels, out_channels),
) )
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
...@@ -55,14 +83,18 @@ class RGAT(nn.Module): ...@@ -55,14 +83,18 @@ class RGAT(nn.Module):
def forward(self, mfgs, x): def forward(self, mfgs, x):
for i in range(len(mfgs)): for i in range(len(mfgs)):
mfg = mfgs[i] mfg = mfgs[i]
x_dst = x[:mfg.num_dst_nodes()] x_dst = x[: mfg.num_dst_nodes()]
n_src = mfg.num_src_nodes() n_src = mfg.num_src_nodes()
n_dst = mfg.num_dst_nodes() n_dst = mfg.num_dst_nodes()
mfg = dgl.block_to_graph(mfg) mfg = dgl.block_to_graph(mfg)
x_skip = self.skips[i](x_dst) x_skip = self.skips[i](x_dst)
for j in range(self.num_etypes): for j in range(self.num_etypes):
subg = mfg.edge_subgraph(mfg.edata['etype'] == j, relabel_nodes=False) subg = mfg.edge_subgraph(
x_skip += self.convs[i][j](subg, (x, x_dst)).view(-1, self.hidden_channels) mfg.edata["etype"] == j, relabel_nodes=False
)
x_skip += self.convs[i][j](subg, (x, x_dst)).view(
-1, self.hidden_channels
)
x = self.norms[i](x_skip) x = self.norms[i](x_skip)
x = F.elu(x) x = F.elu(x)
x = self.dropout(x) x = self.dropout(x)
...@@ -79,46 +111,65 @@ class ExternalNodeCollator(dgl.dataloading.NodeCollator): ...@@ -79,46 +111,65 @@ class ExternalNodeCollator(dgl.dataloading.NodeCollator):
def collate(self, items): def collate(self, items):
input_nodes, output_nodes, mfgs = super().collate(items) input_nodes, output_nodes, mfgs = super().collate(items)
# Copy input features # Copy input features
mfgs[0].srcdata['x'] = torch.FloatTensor(self.feats[input_nodes]) mfgs[0].srcdata["x"] = torch.FloatTensor(self.feats[input_nodes])
mfgs[-1].dstdata['y'] = torch.LongTensor(self.label[output_nodes - self.offset]) mfgs[-1].dstdata["y"] = torch.LongTensor(
self.label[output_nodes - self.offset]
)
return input_nodes, output_nodes, mfgs return input_nodes, output_nodes, mfgs
def train(proc_id, n_gpus, args, dataset, g, feats, paper_offset): def train(proc_id, n_gpus, args, dataset, g, feats, paper_offset):
dev_id = devices[proc_id] dev_id = devices[proc_id]
if n_gpus > 1: if n_gpus > 1:
dist_init_method = 'tcp://{master_ip}:{master_port}'.format( dist_init_method = "tcp://{master_ip}:{master_port}".format(
master_ip='127.0.0.1', master_port='12346') master_ip="127.0.0.1", master_port="12346"
)
world_size = n_gpus world_size = n_gpus
torch.distributed.init_process_group(backend='nccl', torch.distributed.init_process_group(
init_method=dist_init_method, backend="nccl",
world_size=world_size, init_method=dist_init_method,
rank=proc_id) world_size=world_size,
rank=proc_id,
)
torch.cuda.set_device(dev_id) torch.cuda.set_device(dev_id)
print('Loading masks and labels') print("Loading masks and labels")
train_idx = torch.LongTensor(dataset.get_idx_split('train')) + paper_offset train_idx = torch.LongTensor(dataset.get_idx_split("train")) + paper_offset
valid_idx = torch.LongTensor(dataset.get_idx_split('valid')) + paper_offset valid_idx = torch.LongTensor(dataset.get_idx_split("valid")) + paper_offset
label = dataset.paper_label label = dataset.paper_label
print('Initializing dataloader...') print("Initializing dataloader...")
sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 25]) sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 25])
train_collator = ExternalNodeCollator(g, train_idx, sampler, paper_offset, feats, label) train_collator = ExternalNodeCollator(
valid_collator = ExternalNodeCollator(g, valid_idx, sampler, paper_offset, feats, label) g, train_idx, sampler, paper_offset, feats, label
)
valid_collator = ExternalNodeCollator(
g, valid_idx, sampler, paper_offset, feats, label
)
# Necessary according to https://yangkky.github.io/2019/07/08/distributed-pytorch-tutorial.html # Necessary according to https://yangkky.github.io/2019/07/08/distributed-pytorch-tutorial.html
train_sampler = torch.utils.data.distributed.DistributedSampler( train_sampler = torch.utils.data.distributed.DistributedSampler(
train_collator.dataset, num_replicas=world_size, rank=proc_id, shuffle=True, drop_last=False) train_collator.dataset,
num_replicas=world_size,
rank=proc_id,
shuffle=True,
drop_last=False,
)
valid_sampler = torch.utils.data.distributed.DistributedSampler( valid_sampler = torch.utils.data.distributed.DistributedSampler(
valid_collator.dataset, num_replicas=world_size, rank=proc_id, shuffle=True, drop_last=False) valid_collator.dataset,
num_replicas=world_size,
rank=proc_id,
shuffle=True,
drop_last=False,
)
train_dataloader = torch.utils.data.DataLoader( train_dataloader = torch.utils.data.DataLoader(
train_collator.dataset, train_collator.dataset,
batch_size=1024, batch_size=1024,
collate_fn=train_collator.collate, collate_fn=train_collator.collate,
num_workers=4, num_workers=4,
sampler=train_sampler sampler=train_sampler,
) )
valid_dataloader = torch.utils.data.DataLoader( valid_dataloader = torch.utils.data.DataLoader(
...@@ -126,16 +177,27 @@ def train(proc_id, n_gpus, args, dataset, g, feats, paper_offset): ...@@ -126,16 +177,27 @@ def train(proc_id, n_gpus, args, dataset, g, feats, paper_offset):
batch_size=1024, batch_size=1024,
collate_fn=valid_collator.collate, collate_fn=valid_collator.collate,
num_workers=2, num_workers=2,
sampler=valid_sampler sampler=valid_sampler,
) )
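# Note on DistributedSampler usage (standard PyTorch practice, assumed here and
# not a line from this diff): shuffling is typically re-seeded once per epoch, e.g.
#
#     for epoch in range(args.epochs):
#         train_sampler.set_epoch(epoch)
#         ...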
print('Initializing model...') print("Initializing model...")
model = RGAT(dataset.num_paper_features, dataset.num_classes, 1024, 5, 2, 4, 0.5, 'paper').to(dev_id) model = RGAT(
dataset.num_paper_features,
dataset.num_classes,
1024,
5,
2,
4,
0.5,
"paper",
).to(dev_id)
# convert BN to SyncBatchNorm. see https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html # convert BN to SyncBatchNorm. see https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html
model = nn.SyncBatchNorm.convert_sync_batchnorm(model) model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = DistributedDataParallel(model, device_ids=[dev_id], output_device=dev_id) model = DistributedDataParallel(
model, device_ids=[dev_id], output_device=dev_id
)
opt = torch.optim.Adam(model.parameters(), lr=0.001) opt = torch.optim.Adam(model.parameters(), lr=0.001)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=25, gamma=0.25) sched = torch.optim.lr_scheduler.StepLR(opt, step_size=25, gamma=0.25)
...@@ -149,75 +211,97 @@ def train(proc_id, n_gpus, args, dataset, g, feats, paper_offset): ...@@ -149,75 +211,97 @@ def train(proc_id, n_gpus, args, dataset, g, feats, paper_offset):
with tqdm.tqdm(train_dataloader) as tq: with tqdm.tqdm(train_dataloader) as tq:
for i, (input_nodes, output_nodes, mfgs) in enumerate(tq): for i, (input_nodes, output_nodes, mfgs) in enumerate(tq):
mfgs = [g.to(dev_id) for g in mfgs] mfgs = [g.to(dev_id) for g in mfgs]
x = mfgs[0].srcdata['x'] x = mfgs[0].srcdata["x"]
y = mfgs[-1].dstdata['y'] y = mfgs[-1].dstdata["y"]
y_hat = model(mfgs, x) y_hat = model(mfgs, x)
loss = F.cross_entropy(y_hat, y) loss = F.cross_entropy(y_hat, y)
opt.zero_grad() opt.zero_grad()
loss.backward() loss.backward()
opt.step() opt.step()
acc = (y_hat.argmax(1) == y).float().mean() acc = (y_hat.argmax(1) == y).float().mean()
tq.set_postfix({'loss': '%.4f' % loss.item(), 'acc': '%.4f' % acc.item()}, refresh=False) tq.set_postfix(
{"loss": "%.4f" % loss.item(), "acc": "%.4f" % acc.item()},
refresh=False,
)
# eval in each process # eval in each process
model.eval() model.eval()
correct = torch.LongTensor([0]).to(dev_id) correct = torch.LongTensor([0]).to(dev_id)
total = torch.LongTensor([0]).to(dev_id) total = torch.LongTensor([0]).to(dev_id)
for i, (input_nodes, output_nodes, mfgs) in enumerate(tqdm.tqdm(valid_dataloader)): for i, (input_nodes, output_nodes, mfgs) in enumerate(
tqdm.tqdm(valid_dataloader)
):
with torch.no_grad(): with torch.no_grad():
mfgs = [g.to(dev_id) for g in mfgs] mfgs = [g.to(dev_id) for g in mfgs]
x = mfgs[0].srcdata['x'] x = mfgs[0].srcdata["x"]
y = mfgs[-1].dstdata['y'] y = mfgs[-1].dstdata["y"]
y_hat = model(mfgs, x) y_hat = model(mfgs, x)
correct += (y_hat.argmax(1) == y).sum().item() correct += (y_hat.argmax(1) == y).sum().item()
total += y_hat.shape[0] total += y_hat.shape[0]
# `reduce` data into process 0 # `reduce` data into process 0
torch.distributed.reduce(correct, dst=0, op=torch.distributed.ReduceOp.SUM) torch.distributed.reduce(
torch.distributed.reduce(total, dst=0, op=torch.distributed.ReduceOp.SUM) correct, dst=0, op=torch.distributed.ReduceOp.SUM
)
torch.distributed.reduce(
total, dst=0, op=torch.distributed.ReduceOp.SUM
)
acc = (correct / total).item() acc = (correct / total).item()
sched.step() sched.step()
# only process 0 prints the accuracy and saves the model # only process 0 prints the accuracy and saves the model
if proc_id == 0: if proc_id == 0:
print('Validation accuracy:', acc) print("Validation accuracy:", acc)
if best_acc < acc: if best_acc < acc:
best_acc = acc best_acc = acc
print('Updating best model...') print("Updating best model...")
torch.save(model.state_dict(), args.model_path) torch.save(model.state_dict(), args.model_path)
def test(args, dataset, g, feats, paper_offset): def test(args, dataset, g, feats, paper_offset):
print('Loading masks and labels...') print("Loading masks and labels...")
valid_idx = torch.LongTensor(dataset.get_idx_split('valid')) + paper_offset valid_idx = torch.LongTensor(dataset.get_idx_split("valid")) + paper_offset
test_idx = torch.LongTensor(dataset.get_idx_split('test')) + paper_offset test_idx = torch.LongTensor(dataset.get_idx_split("test")) + paper_offset
label = dataset.paper_label label = dataset.paper_label
print('Initializing data loader...') print("Initializing data loader...")
sampler = dgl.dataloading.MultiLayerNeighborSampler([160, 160]) sampler = dgl.dataloading.MultiLayerNeighborSampler([160, 160])
valid_collator = ExternalNodeCollator(g, valid_idx, sampler, paper_offset, feats, label) valid_collator = ExternalNodeCollator(
g, valid_idx, sampler, paper_offset, feats, label
)
valid_dataloader = torch.utils.data.DataLoader( valid_dataloader = torch.utils.data.DataLoader(
valid_collator.dataset, valid_collator.dataset,
batch_size=16, batch_size=16,
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
collate_fn=valid_collator.collate, collate_fn=valid_collator.collate,
num_workers=2 num_workers=2,
)
test_collator = ExternalNodeCollator(
g, test_idx, sampler, paper_offset, feats, label
) )
test_collator = ExternalNodeCollator(g, test_idx, sampler, paper_offset, feats, label)
test_dataloader = torch.utils.data.DataLoader( test_dataloader = torch.utils.data.DataLoader(
test_collator.dataset, test_collator.dataset,
batch_size=16, batch_size=16,
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
collate_fn=test_collator.collate, collate_fn=test_collator.collate,
num_workers=4 num_workers=4,
) )
print('Loading model...') print("Loading model...")
model = RGAT(dataset.num_paper_features, dataset.num_classes, 1024, 5, 2, 4, 0.5, 'paper').cuda() model = RGAT(
dataset.num_paper_features,
dataset.num_classes,
1024,
5,
2,
4,
0.5,
"paper",
).cuda()
# Load the DDP-trained model's parameters; the 'module.' prefix must be removed from the keys. # Load the DDP-trained model's parameters; the 'module.' prefix must be removed from the keys.
state_dict = torch.load(args.model_path) state_dict = torch.load(args.model_path)
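# One common way to strip the DDP "module." prefix before load_state_dict
# (an illustrative sketch; the example's actual conversion code is elided in this view):
#
#     state_dict = OrderedDict(
#         (k[len("module."):] if k.startswith("module.") else k, v)
#         for k, v in state_dict.items()
#     )
#     model.load_state_dict(state_dict)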
...@@ -229,41 +313,73 @@ def test(args, dataset, g, feats, paper_offset): ...@@ -229,41 +313,73 @@ def test(args, dataset, g, feats, paper_offset):
model.eval() model.eval()
correct = total = 0 correct = total = 0
for i, (input_nodes, output_nodes, mfgs) in enumerate(tqdm.tqdm(valid_dataloader)): for i, (input_nodes, output_nodes, mfgs) in enumerate(
tqdm.tqdm(valid_dataloader)
):
with torch.no_grad(): with torch.no_grad():
mfgs = [g.to('cuda') for g in mfgs] mfgs = [g.to("cuda") for g in mfgs]
x = mfgs[0].srcdata['x'] x = mfgs[0].srcdata["x"]
y = mfgs[-1].dstdata['y'] y = mfgs[-1].dstdata["y"]
y_hat = model(mfgs, x) y_hat = model(mfgs, x)
correct += (y_hat.argmax(1) == y).sum().item() correct += (y_hat.argmax(1) == y).sum().item()
total += y_hat.shape[0] total += y_hat.shape[0]
acc = correct / total acc = correct / total
print('Validation accuracy:', acc) print("Validation accuracy:", acc)
evaluator = MAG240MEvaluator() evaluator = MAG240MEvaluator()
y_preds = [] y_preds = []
for i, (input_nodes, output_nodes, mfgs) in enumerate(tqdm.tqdm(test_dataloader)): for i, (input_nodes, output_nodes, mfgs) in enumerate(
tqdm.tqdm(test_dataloader)
):
with torch.no_grad(): with torch.no_grad():
mfgs = [g.to('cuda') for g in mfgs] mfgs = [g.to("cuda") for g in mfgs]
x = mfgs[0].srcdata['x'] x = mfgs[0].srcdata["x"]
y = mfgs[-1].dstdata['y'] y = mfgs[-1].dstdata["y"]
y_hat = model(mfgs, x) y_hat = model(mfgs, x)
y_preds.append(y_hat.argmax(1).cpu()) y_preds.append(y_hat.argmax(1).cpu())
evaluator.save_test_submission({'y_pred': torch.cat(y_preds)}, args.submission_path) evaluator.save_test_submission(
{"y_pred": torch.cat(y_preds)}, args.submission_path
)
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--rootdir', type=str, default='.', help='Directory to download the OGB dataset.') parser.add_argument(
parser.add_argument('--graph-path', type=str, default='./graph.dgl', help='Path to the graph.') "--rootdir",
parser.add_argument('--full-feature-path', type=str, default='./full.npy', type=str,
help='Path to the features of all nodes.') default=".",
parser.add_argument('--epochs', type=int, default=100, help='Number of epochs.') help="Directory to download the OGB dataset.",
parser.add_argument('--model-path', type=str, default='./model_ddp.pt', help='Path to store the best model.') )
parser.add_argument('--submission-path', type=str, default='./results_ddp', help='Submission directory.') parser.add_argument(
parser.add_argument('--gpus', type=str, default='0,1,2') "--graph-path",
type=str,
default="./graph.dgl",
help="Path to the graph.",
)
parser.add_argument(
"--full-feature-path",
type=str,
default="./full.npy",
help="Path to the features of all nodes.",
)
parser.add_argument(
"--epochs", type=int, default=100, help="Number of epochs."
)
parser.add_argument(
"--model-path",
type=str,
default="./model_ddp.pt",
help="Path to store the best model.",
)
parser.add_argument(
"--submission-path",
type=str,
default="./results_ddp",
help="Submission directory.",
)
parser.add_argument("--gpus", type=str, default="0,1,2")
args = parser.parse_args() args = parser.parse_args()
devices = list(map(int, args.gpus.split(','))) devices = list(map(int, args.gpus.split(",")))
n_gpus = len(devices) n_gpus = len(devices)
if n_gpus <= 1: if n_gpus <= 1:
...@@ -272,16 +388,25 @@ if __name__ == '__main__': ...@@ -272,16 +388,25 @@ if __name__ == '__main__':
dataset = MAG240MDataset(root=args.rootdir) dataset = MAG240MDataset(root=args.rootdir)
print('Loading graph') print("Loading graph")
(g,), _ = dgl.load_graphs(args.graph_path) (g,), _ = dgl.load_graphs(args.graph_path)
g = g.formats(['csc']) g = g.formats(["csc"])
print('Loading features') print("Loading features")
paper_offset = dataset.num_authors + dataset.num_institutions paper_offset = dataset.num_authors + dataset.num_institutions
num_nodes = paper_offset + dataset.num_papers num_nodes = paper_offset + dataset.num_papers
num_features = dataset.num_paper_features num_features = dataset.num_paper_features
feats = np.memmap(args.full_feature_path, mode='r', dtype='float16', shape=(num_nodes, num_features)) feats = np.memmap(
args.full_feature_path,
mode="r",
dtype="float16",
shape=(num_nodes, num_features),
)
mp.spawn(train, args=(n_gpus, args, dataset, g, feats, paper_offset), nprocs=n_gpus) mp.spawn(
train,
args=(n_gpus, args, dataset, g, feats, paper_offset),
nprocs=n_gpus,
)
test(args, dataset, g, feats, paper_offset) test(args, dataset, g, feats, paper_offset)
import dgl
import dgl.function as fn
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
import dgl
import dgl.function as fn
from dgl.nn.pytorch import SumPooling from dgl.nn.pytorch import SumPooling
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
### GIN convolution along the graph structure ### GIN convolution along the graph structure
class GINConv(nn.Module): class GINConv(nn.Module):
def __init__(self, emb_dim): def __init__(self, emb_dim):
''' """
emb_dim (int): node embedding dimensionality emb_dim (int): node embedding dimensionality
''' """
super(GINConv, self).__init__() super(GINConv, self).__init__()
self.mlp = nn.Sequential(nn.Linear(emb_dim, emb_dim), self.mlp = nn.Sequential(
nn.BatchNorm1d(emb_dim), nn.Linear(emb_dim, emb_dim),
nn.ReLU(), nn.BatchNorm1d(emb_dim),
nn.Linear(emb_dim, emb_dim)) nn.ReLU(),
nn.Linear(emb_dim, emb_dim),
)
self.eps = nn.Parameter(torch.Tensor([0])) self.eps = nn.Parameter(torch.Tensor([0]))
self.bond_encoder = BondEncoder(emb_dim = emb_dim) self.bond_encoder = BondEncoder(emb_dim=emb_dim)
def forward(self, g, x, edge_attr): def forward(self, g, x, edge_attr):
with g.local_scope(): with g.local_scope():
edge_embedding = self.bond_encoder(edge_attr) edge_embedding = self.bond_encoder(edge_attr)
g.ndata['x'] = x g.ndata["x"] = x
g.apply_edges(fn.copy_u('x', 'm')) g.apply_edges(fn.copy_u("x", "m"))
g.edata['m'] = F.relu(g.edata['m'] + edge_embedding) g.edata["m"] = F.relu(g.edata["m"] + edge_embedding)
g.update_all(fn.copy_e('m', 'm'), fn.sum('m', 'new_x')) g.update_all(fn.copy_e("m", "m"), fn.sum("m", "new_x"))
out = self.mlp((1 + self.eps) * x + g.ndata['new_x']) out = self.mlp((1 + self.eps) * x + g.ndata["new_x"])
return out return out
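For reference, the forward pass above realizes the edge-aware GIN update: every node sums ReLU(neighbor feature + bond embedding) over its incoming edges, adds (1 + eps) times its own feature, and pushes the result through the MLP. A hedged dense sketch of that aggregation, written outside DGL's message-passing API:

import torch

# Illustrative only: src/dst are edge endpoint indices, x node features, edge_emb bond embeddings.
def gin_aggregate(x, src, dst, edge_emb, eps):
    msg = torch.relu(x[src] + edge_emb)                 # fn.copy_u, plus edge embedding and ReLU
    agg = torch.zeros_like(x).index_add_(0, dst, msg)   # fn.sum("m", "new_x") at destination nodes
    return (1 + eps) * x + agg                          # fed into self.mlp in GINConv above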
### GCN convolution along the graph structure ### GCN convolution along the graph structure
class GCNConv(nn.Module): class GCNConv(nn.Module):
def __init__(self, emb_dim): def __init__(self, emb_dim):
''' """
emb_dim (int): node embedding dimensionality emb_dim (int): node embedding dimensionality
''' """
super(GCNConv, self).__init__() super(GCNConv, self).__init__()
self.linear = nn.Linear(emb_dim, emb_dim) self.linear = nn.Linear(emb_dim, emb_dim)
self.root_emb = nn.Embedding(1, emb_dim) self.root_emb = nn.Embedding(1, emb_dim)
self.bond_encoder = BondEncoder(emb_dim = emb_dim) self.bond_encoder = BondEncoder(emb_dim=emb_dim)
def forward(self, g, x, edge_attr): def forward(self, g, x, edge_attr):
with g.local_scope(): with g.local_scope():
...@@ -56,29 +60,43 @@ class GCNConv(nn.Module): ...@@ -56,29 +60,43 @@ class GCNConv(nn.Module):
# Molecular graphs are undirected # Molecular graphs are undirected
# g.out_degrees() is the same as g.in_degrees() # g.out_degrees() is the same as g.in_degrees()
degs = (g.out_degrees().float() + 1).to(x.device) degs = (g.out_degrees().float() + 1).to(x.device)
norm = torch.pow(degs, -0.5).unsqueeze(-1) # (N, 1) norm = torch.pow(degs, -0.5).unsqueeze(-1) # (N, 1)
g.ndata['norm'] = norm g.ndata["norm"] = norm
g.apply_edges(fn.u_mul_v('norm', 'norm', 'norm')) g.apply_edges(fn.u_mul_v("norm", "norm", "norm"))
g.ndata['x'] = x g.ndata["x"] = x
g.apply_edges(fn.copy_u('x', 'm')) g.apply_edges(fn.copy_u("x", "m"))
g.edata['m'] = g.edata['norm'] * F.relu(g.edata['m'] + edge_embedding) g.edata["m"] = g.edata["norm"] * F.relu(
g.update_all(fn.copy_e('m', 'm'), fn.sum('m', 'new_x')) g.edata["m"] + edge_embedding
out = g.ndata['new_x'] + F.relu(x + self.root_emb.weight) * 1. / degs.view(-1, 1) )
g.update_all(fn.copy_e("m", "m"), fn.sum("m", "new_x"))
out = g.ndata["new_x"] + F.relu(
x + self.root_emb.weight
) * 1.0 / degs.view(-1, 1)
return out return out
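The degree handling above is symmetric normalization: with deg_v = out-degree + 1 (equal to in-degree + 1 on these undirected molecular graphs), each edge message is scaled by deg_u^-0.5 * deg_v^-0.5, and a self-term ReLU(x_v + root_emb) / deg_v is added. A hedged dense sketch, with x standing for the node features entering the aggregation:

import torch

# Illustrative only: src/dst are edge endpoints, edge_emb the encoded bond features,
# root a single learned embedding row shared by all nodes.
def gcn_aggregate(x, src, dst, edge_emb, root):
    degs = torch.zeros(x.shape[0]).index_add_(0, dst, torch.ones(dst.shape[0])) + 1
    norm = degs.pow(-0.5)
    msg = (norm[src] * norm[dst]).unsqueeze(1) * torch.relu(x[src] + edge_emb)
    agg = torch.zeros_like(x).index_add_(0, dst, msg)
    return agg + torch.relu(x + root) / degs.unsqueeze(1)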
### GNN to generate node embedding ### GNN to generate node embedding
class GNN_node(nn.Module): class GNN_node(nn.Module):
""" """
Output: Output:
node representations node representations
""" """
def __init__(self, num_layers, emb_dim, drop_ratio = 0.5, JK = "last", residual = False, gnn_type = 'gin'):
''' def __init__(
num_layers (int): number of GNN message passing layers self,
emb_dim (int): node embedding dimensionality num_layers,
''' emb_dim,
drop_ratio=0.5,
JK="last",
residual=False,
gnn_type="gin",
):
"""
num_layers (int): number of GNN message passing layers
emb_dim (int): node embedding dimensionality
"""
super(GNN_node, self).__init__() super(GNN_node, self).__init__()
self.num_layers = num_layers self.num_layers = num_layers
...@@ -97,12 +115,12 @@ class GNN_node(nn.Module): ...@@ -97,12 +115,12 @@ class GNN_node(nn.Module):
self.batch_norms = nn.ModuleList() self.batch_norms = nn.ModuleList()
for layer in range(num_layers): for layer in range(num_layers):
if gnn_type == 'gin': if gnn_type == "gin":
self.convs.append(GINConv(emb_dim)) self.convs.append(GINConv(emb_dim))
elif gnn_type == 'gcn': elif gnn_type == "gcn":
self.convs.append(GCNConv(emb_dim)) self.convs.append(GCNConv(emb_dim))
else: else:
ValueError('Undefined GNN type called {}'.format(gnn_type)) ValueError("Undefined GNN type called {}".format(gnn_type))
self.batch_norms.append(nn.BatchNorm1d(emb_dim)) self.batch_norms.append(nn.BatchNorm1d(emb_dim))
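One detail worth flagging in the branch above (the virtual-node variant below repeats it): the ValueError is constructed but never raised, so an unrecognized gnn_type is silently ignored. If raising is the intent, a hedged fix sketch, reusing the GINConv/GCNConv classes defined in this file:

def make_conv(gnn_type, emb_dim):
    if gnn_type == "gin":
        return GINConv(emb_dim)
    elif gnn_type == "gcn":
        return GCNConv(emb_dim)
    # raise instead of constructing and discarding the exception
    raise ValueError("Undefined GNN type called {}".format(gnn_type))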
...@@ -115,10 +133,12 @@ class GNN_node(nn.Module): ...@@ -115,10 +133,12 @@ class GNN_node(nn.Module):
h = self.batch_norms[layer](h) h = self.batch_norms[layer](h)
if layer == self.num_layers - 1: if layer == self.num_layers - 1:
#remove relu for the last layer # remove relu for the last layer
h = F.dropout(h, self.drop_ratio, training = self.training) h = F.dropout(h, self.drop_ratio, training=self.training)
else: else:
h = F.dropout(F.relu(h), self.drop_ratio, training = self.training) h = F.dropout(
F.relu(h), self.drop_ratio, training=self.training
)
if self.residual: if self.residual:
h += h_list[layer] h += h_list[layer]
...@@ -142,11 +162,20 @@ class GNN_node_Virtualnode(nn.Module): ...@@ -142,11 +162,20 @@ class GNN_node_Virtualnode(nn.Module):
Output: Output:
node representations node representations
""" """
def __init__(self, num_layers, emb_dim, drop_ratio = 0.5, JK = "last", residual = False, gnn_type = 'gin'):
''' def __init__(
num_layers (int): number of GNN message passing layers self,
emb_dim (int): node embedding dimensionality num_layers,
''' emb_dim,
drop_ratio=0.5,
JK="last",
residual=False,
gnn_type="gin",
):
"""
num_layers (int): number of GNN message passing layers
emb_dim (int): node embedding dimensionality
"""
super(GNN_node_Virtualnode, self).__init__() super(GNN_node_Virtualnode, self).__init__()
self.num_layers = num_layers self.num_layers = num_layers
...@@ -173,31 +202,38 @@ class GNN_node_Virtualnode(nn.Module): ...@@ -173,31 +202,38 @@ class GNN_node_Virtualnode(nn.Module):
self.mlp_virtualnode_list = nn.ModuleList() self.mlp_virtualnode_list = nn.ModuleList()
for layer in range(num_layers): for layer in range(num_layers):
if gnn_type == 'gin': if gnn_type == "gin":
self.convs.append(GINConv(emb_dim)) self.convs.append(GINConv(emb_dim))
elif gnn_type == 'gcn': elif gnn_type == "gcn":
self.convs.append(GCNConv(emb_dim)) self.convs.append(GCNConv(emb_dim))
else: else:
ValueError('Undefined GNN type called {}'.format(gnn_type)) ValueError("Undefined GNN type called {}".format(gnn_type))
self.batch_norms.append(nn.BatchNorm1d(emb_dim)) self.batch_norms.append(nn.BatchNorm1d(emb_dim))
for layer in range(num_layers - 1): for layer in range(num_layers - 1):
self.mlp_virtualnode_list.append(nn.Sequential(nn.Linear(emb_dim, emb_dim), self.mlp_virtualnode_list.append(
nn.BatchNorm1d(emb_dim), nn.Sequential(
nn.ReLU(), nn.Linear(emb_dim, emb_dim),
nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim),
nn.BatchNorm1d(emb_dim), nn.ReLU(),
nn.ReLU())) nn.Linear(emb_dim, emb_dim),
nn.BatchNorm1d(emb_dim),
nn.ReLU(),
)
)
self.pool = SumPooling() self.pool = SumPooling()
def forward(self, g, x, edge_attr): def forward(self, g, x, edge_attr):
### virtual node embeddings for graphs ### virtual node embeddings for graphs
virtualnode_embedding = self.virtualnode_embedding( virtualnode_embedding = self.virtualnode_embedding(
torch.zeros(g.batch_size).to(x.dtype).to(x.device)) torch.zeros(g.batch_size).to(x.dtype).to(x.device)
)
h_list = [self.atom_encoder(x)] h_list = [self.atom_encoder(x)]
batch_id = dgl.broadcast_nodes(g, torch.arange(g.batch_size).to(x.device)) batch_id = dgl.broadcast_nodes(
g, torch.arange(g.batch_size).to(x.device)
)
for layer in range(self.num_layers): for layer in range(self.num_layers):
### add message from virtual nodes to graph nodes ### add message from virtual nodes to graph nodes
h_list[layer] = h_list[layer] + virtualnode_embedding[batch_id] h_list[layer] = h_list[layer] + virtualnode_embedding[batch_id]
...@@ -206,10 +242,12 @@ class GNN_node_Virtualnode(nn.Module): ...@@ -206,10 +242,12 @@ class GNN_node_Virtualnode(nn.Module):
h = self.convs[layer](g, h_list[layer], edge_attr) h = self.convs[layer](g, h_list[layer], edge_attr)
h = self.batch_norms[layer](h) h = self.batch_norms[layer](h)
if layer == self.num_layers - 1: if layer == self.num_layers - 1:
#remove relu for the last layer # remove relu for the last layer
h = F.dropout(h, self.drop_ratio, training = self.training) h = F.dropout(h, self.drop_ratio, training=self.training)
else: else:
h = F.dropout(F.relu(h), self.drop_ratio, training = self.training) h = F.dropout(
F.relu(h), self.drop_ratio, training=self.training
)
if self.residual: if self.residual:
h = h + h_list[layer] h = h + h_list[layer]
...@@ -219,17 +257,26 @@ class GNN_node_Virtualnode(nn.Module): ...@@ -219,17 +257,26 @@ class GNN_node_Virtualnode(nn.Module):
### update the virtual nodes ### update the virtual nodes
if layer < self.num_layers - 1: if layer < self.num_layers - 1:
### add message from graph nodes to virtual nodes ### add message from graph nodes to virtual nodes
virtualnode_embedding_temp = self.pool(g, h_list[layer]) + virtualnode_embedding virtualnode_embedding_temp = (
self.pool(g, h_list[layer]) + virtualnode_embedding
)
### transform virtual nodes using MLP ### transform virtual nodes using MLP
virtualnode_embedding_temp = self.mlp_virtualnode_list[layer]( virtualnode_embedding_temp = self.mlp_virtualnode_list[layer](
virtualnode_embedding_temp) virtualnode_embedding_temp
)
if self.residual: if self.residual:
virtualnode_embedding = virtualnode_embedding + F.dropout( virtualnode_embedding = virtualnode_embedding + F.dropout(
virtualnode_embedding_temp, self.drop_ratio, training = self.training) virtualnode_embedding_temp,
self.drop_ratio,
training=self.training,
)
else: else:
virtualnode_embedding = F.dropout( virtualnode_embedding = F.dropout(
virtualnode_embedding_temp, self.drop_ratio, training = self.training) virtualnode_embedding_temp,
self.drop_ratio,
training=self.training,
)
### Different implementations of Jk-concat ### Different implementations of Jk-concat
if self.JK == "last": if self.JK == "last":
......
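To summarize the virtual-node scheme in the forward pass above: one learnable embedding per graph is broadcast to all of that graph's nodes before every layer, and after every layer except the last it is refreshed from a sum-pooled readout pushed through a small MLP, with dropout and an optional residual. A hedged, stripped-down sketch of a single round:

import torch
import torch.nn.functional as F

# Illustrative only: h is (num_nodes, emb_dim), vn is (batch_size, emb_dim),
# batch_id maps every node to its graph, mlp is one entry of mlp_virtualnode_list.
def virtual_node_round(h, vn, batch_id, gnn_layer, mlp, drop_ratio, training):
    h = h + vn[batch_id]                                       # virtual node -> graph nodes
    h = gnn_layer(h)                                           # placeholder for the GIN/GCN layer
    pooled = torch.zeros_like(vn).index_add_(0, batch_id, h)   # SumPooling per graph
    vn = F.dropout(mlp(pooled + vn), drop_ratio, training=training)  # graph nodes -> virtual node
    return h, vn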
import torch import torch
import torch.nn as nn import torch.nn as nn
from conv import GNN_node, GNN_node_Virtualnode
from dgl.nn.pytorch import SumPooling, AvgPooling, MaxPooling, GlobalAttentionPooling, Set2Set from dgl.nn.pytorch import (
AvgPooling,
GlobalAttentionPooling,
MaxPooling,
Set2Set,
SumPooling,
)
from conv import GNN_node, GNN_node_Virtualnode
class GNN(nn.Module): class GNN(nn.Module):
def __init__(
def __init__(self, num_tasks = 1, num_layers = 5, emb_dim = 300, gnn_type = 'gin', self,
virtual_node = True, residual = False, drop_ratio = 0, JK = "last", num_tasks=1,
graph_pooling = "sum"): num_layers=5,
''' emb_dim=300,
num_tasks (int): number of labels to be predicted gnn_type="gin",
virtual_node (bool): whether to add virtual node or not virtual_node=True,
''' residual=False,
drop_ratio=0,
JK="last",
graph_pooling="sum",
):
"""
num_tasks (int): number of labels to be predicted
virtual_node (bool): whether to add virtual node or not
"""
super(GNN, self).__init__() super(GNN, self).__init__()
self.num_layers = num_layers self.num_layers = num_layers
...@@ -28,14 +42,23 @@ class GNN(nn.Module): ...@@ -28,14 +42,23 @@ class GNN(nn.Module):
### GNN to generate node embeddings ### GNN to generate node embeddings
if virtual_node: if virtual_node:
self.gnn_node = GNN_node_Virtualnode(num_layers, emb_dim, JK = JK, self.gnn_node = GNN_node_Virtualnode(
drop_ratio = drop_ratio, num_layers,
residual = residual, emb_dim,
gnn_type = gnn_type) JK=JK,
drop_ratio=drop_ratio,
residual=residual,
gnn_type=gnn_type,
)
else: else:
self.gnn_node = GNN_node(num_layers, emb_dim, JK = JK, drop_ratio = drop_ratio, self.gnn_node = GNN_node(
residual = residual, gnn_type = gnn_type) num_layers,
emb_dim,
JK=JK,
drop_ratio=drop_ratio,
residual=residual,
gnn_type=gnn_type,
)
### Pooling function to generate whole-graph embeddings ### Pooling function to generate whole-graph embeddings
if self.graph_pooling == "sum": if self.graph_pooling == "sum":
...@@ -46,18 +69,21 @@ class GNN(nn.Module): ...@@ -46,18 +69,21 @@ class GNN(nn.Module):
self.pool = MaxPooling self.pool = MaxPooling
elif self.graph_pooling == "attention": elif self.graph_pooling == "attention":
self.pool = GlobalAttentionPooling( self.pool = GlobalAttentionPooling(
gate_nn = nn.Sequential(nn.Linear(emb_dim, 2*emb_dim), gate_nn=nn.Sequential(
nn.BatchNorm1d(2*emb_dim), nn.Linear(emb_dim, 2 * emb_dim),
nn.ReLU(), nn.BatchNorm1d(2 * emb_dim),
nn.Linear(2*emb_dim, 1))) nn.ReLU(),
nn.Linear(2 * emb_dim, 1),
)
)
elif self.graph_pooling == "set2set": elif self.graph_pooling == "set2set":
self.pool = Set2Set(emb_dim, n_iters = 2, n_layers = 2) self.pool = Set2Set(emb_dim, n_iters=2, n_layers=2)
else: else:
raise ValueError("Invalid graph pooling type.") raise ValueError("Invalid graph pooling type.")
if graph_pooling == "set2set": if graph_pooling == "set2set":
self.graph_pred_linear = nn.Linear(2*self.emb_dim, self.num_tasks) self.graph_pred_linear = nn.Linear(2 * self.emb_dim, self.num_tasks)
else: else:
self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks) self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks)
......
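Two small things to note about the pooling block above. First, set2set is the only readout that changes the width: DGL's Set2Set concatenates its LSTM query with the attention readout, producing 2 * emb_dim features per graph, which is why the prediction head doubles its input size in that branch. Second, the max branch assigns the MaxPooling class itself rather than an instance (no parentheses), which looks unintended. A hedged sketch of the width bookkeeping:

import torch.nn as nn
from dgl.nn.pytorch import MaxPooling, Set2Set, SumPooling

emb_dim, num_tasks = 300, 1
sum_pool = SumPooling()                              # graph embedding: (batch, emb_dim)
max_pool = MaxPooling()                              # note the instantiation
s2s_pool = Set2Set(emb_dim, n_iters=2, n_layers=2)   # graph embedding: (batch, 2 * emb_dim)
head_default = nn.Linear(emb_dim, num_tasks)
head_set2set = nn.Linear(2 * emb_dim, num_tasks)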
import argparse import argparse
import dgl
import numpy as np
import os import os
import random import random
import numpy as np
import torch import torch
import torch.optim as optim import torch.optim as optim
from gnn import GNN
from ogb.lsc import DglPCQM4MDataset, PCQM4MEvaluator from ogb.lsc import DglPCQM4MDataset, PCQM4MEvaluator
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm from tqdm import tqdm
from gnn import GNN import dgl
reg_criterion = torch.nn.L1Loss() reg_criterion = torch.nn.L1Loss()
...@@ -30,11 +31,13 @@ def train(model, device, loader, optimizer): ...@@ -30,11 +31,13 @@ def train(model, device, loader, optimizer):
for step, (bg, labels) in enumerate(tqdm(loader, desc="Iteration")): for step, (bg, labels) in enumerate(tqdm(loader, desc="Iteration")):
bg = bg.to(device) bg = bg.to(device)
x = bg.ndata.pop('feat') x = bg.ndata.pop("feat")
edge_attr = bg.edata.pop('feat') edge_attr = bg.edata.pop("feat")
labels = labels.to(device) labels = labels.to(device)
pred = model(bg, x, edge_attr).view(-1,) pred = model(bg, x, edge_attr).view(
-1,
)
optimizer.zero_grad() optimizer.zero_grad()
loss = reg_criterion(pred, labels) loss = reg_criterion(pred, labels)
loss.backward() loss.backward()
...@@ -52,12 +55,14 @@ def eval(model, device, loader, evaluator): ...@@ -52,12 +55,14 @@ def eval(model, device, loader, evaluator):
for step, (bg, labels) in enumerate(tqdm(loader, desc="Iteration")): for step, (bg, labels) in enumerate(tqdm(loader, desc="Iteration")):
bg = bg.to(device) bg = bg.to(device)
x = bg.ndata.pop('feat') x = bg.ndata.pop("feat")
edge_attr = bg.edata.pop('feat') edge_attr = bg.edata.pop("feat")
labels = labels.to(device) labels = labels.to(device)
with torch.no_grad(): with torch.no_grad():
pred = model(bg, x, edge_attr).view(-1, ) pred = model(bg, x, edge_attr).view(
-1,
)
y_true.append(labels.view(pred.shape).detach().cpu()) y_true.append(labels.view(pred.shape).detach().cpu())
y_pred.append(pred.detach().cpu()) y_pred.append(pred.detach().cpu())
...@@ -76,11 +81,13 @@ def test(model, device, loader): ...@@ -76,11 +81,13 @@ def test(model, device, loader):
for step, (bg, _) in enumerate(tqdm(loader, desc="Iteration")): for step, (bg, _) in enumerate(tqdm(loader, desc="Iteration")):
bg = bg.to(device) bg = bg.to(device)
x = bg.ndata.pop('feat') x = bg.ndata.pop("feat")
edge_attr = bg.edata.pop('feat') edge_attr = bg.edata.pop("feat")
with torch.no_grad(): with torch.no_grad():
pred = model(bg, x, edge_attr).view(-1, ) pred = model(bg, x, edge_attr).view(
-1,
)
y_pred.append(pred.detach().cpu()) y_pred.append(pred.detach().cpu())
...@@ -91,37 +98,88 @@ def test(model, device, loader): ...@@ -91,37 +98,88 @@ def test(model, device, loader):
def main(): def main():
# Training settings # Training settings
parser = argparse.ArgumentParser(description='GNN baselines on pcqm4m with DGL') parser = argparse.ArgumentParser(
parser.add_argument('--seed', type=int, default=42, description="GNN baselines on pcqm4m with DGL"
help='random seed to use (default: 42)') )
parser.add_argument('--device', type=int, default=0, parser.add_argument(
help='which gpu to use if any (default: 0)') "--seed", type=int, default=42, help="random seed to use (default: 42)"
parser.add_argument('--gnn', type=str, default='gin-virtual', )
help='GNN to use, which can be from ' parser.add_argument(
'[gin, gin-virtual, gcn, gcn-virtual] (default: gin-virtual)') "--device",
parser.add_argument('--graph_pooling', type=str, default='sum', type=int,
help='graph pooling strategy mean or sum (default: sum)') default=0,
parser.add_argument('--drop_ratio', type=float, default=0, help="which gpu to use if any (default: 0)",
help='dropout ratio (default: 0)') )
parser.add_argument('--num_layers', type=int, default=5, parser.add_argument(
help='number of GNN message passing layers (default: 5)') "--gnn",
parser.add_argument('--emb_dim', type=int, default=600, type=str,
help='dimensionality of hidden units in GNNs (default: 600)') default="gin-virtual",
parser.add_argument('--train_subset', action='store_true', help="GNN to use, which can be from "
help='use 10% of the training set for training') "[gin, gin-virtual, gcn, gcn-virtual] (default: gin-virtual)",
parser.add_argument('--batch_size', type=int, default=256, )
help='input batch size for training (default: 256)') parser.add_argument(
parser.add_argument('--epochs', type=int, default=100, "--graph_pooling",
help='number of epochs to train (default: 100)') type=str,
parser.add_argument('--num_workers', type=int, default=0, default="sum",
help='number of workers (default: 0)') help="graph pooling strategy mean or sum (default: sum)",
parser.add_argument('--log_dir', type=str, default="", )
help='tensorboard log directory. If not specified, ' parser.add_argument(
'tensorboard will not be used.') "--drop_ratio", type=float, default=0, help="dropout ratio (default: 0)"
parser.add_argument('--checkpoint_dir', type=str, default='', )
help='directory to save checkpoint') parser.add_argument(
parser.add_argument('--save_test_dir', type=str, default='', "--num_layers",
help='directory to save test submission file') type=int,
default=5,
help="number of GNN message passing layers (default: 5)",
)
parser.add_argument(
"--emb_dim",
type=int,
default=600,
help="dimensionality of hidden units in GNNs (default: 600)",
)
parser.add_argument(
"--train_subset",
action="store_true",
help="use 10% of the training set for training",
)
parser.add_argument(
"--batch_size",
type=int,
default=256,
help="input batch size for training (default: 256)",
)
parser.add_argument(
"--epochs",
type=int,
default=100,
help="number of epochs to train (default: 100)",
)
parser.add_argument(
"--num_workers",
type=int,
default=0,
help="number of workers (default: 0)",
)
parser.add_argument(
"--log_dir",
type=str,
default="",
help="tensorboard log directory. If not specified, "
"tensorboard will not be used.",
)
parser.add_argument(
"--checkpoint_dir",
type=str,
default="",
help="directory to save checkpoint",
)
parser.add_argument(
"--save_test_dir",
type=str,
default="",
help="directory to save test submission file",
)
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
...@@ -137,7 +195,7 @@ def main(): ...@@ -137,7 +195,7 @@ def main():
device = torch.device("cpu") device = torch.device("cpu")
### automatic dataloading and splitting ### automatic dataloading and splitting
dataset = DglPCQM4MDataset(root='dataset/') dataset = DglPCQM4MDataset(root="dataset/")
# split_idx['train'], split_idx['valid'], split_idx['test'] # split_idx['train'], split_idx['valid'], split_idx['test']
# separately gives a 1D int64 tensor # separately gives a 1D int64 tensor
...@@ -148,47 +206,77 @@ def main(): ...@@ -148,47 +206,77 @@ def main():
if args.train_subset: if args.train_subset:
subset_ratio = 0.1 subset_ratio = 0.1
subset_idx = torch.randperm(len(split_idx["train"]))[:int(subset_ratio * len(split_idx["train"]))] subset_idx = torch.randperm(len(split_idx["train"]))[
train_loader = DataLoader(dataset[split_idx["train"][subset_idx]], batch_size=args.batch_size, shuffle=True, : int(subset_ratio * len(split_idx["train"]))
num_workers=args.num_workers, collate_fn=collate_dgl) ]
train_loader = DataLoader(
dataset[split_idx["train"][subset_idx]],
batch_size=args.batch_size,
shuffle=True,
num_workers=args.num_workers,
collate_fn=collate_dgl,
)
else: else:
train_loader = DataLoader(dataset[split_idx["train"]], batch_size=args.batch_size, shuffle=True, train_loader = DataLoader(
num_workers=args.num_workers, collate_fn=collate_dgl) dataset[split_idx["train"]],
batch_size=args.batch_size,
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=args.batch_size, shuffle=False, shuffle=True,
num_workers=args.num_workers, collate_fn=collate_dgl) num_workers=args.num_workers,
collate_fn=collate_dgl,
if args.save_test_dir != '': )
test_loader = DataLoader(dataset[split_idx["test"]], batch_size=args.batch_size, shuffle=False,
num_workers=args.num_workers, collate_fn=collate_dgl) valid_loader = DataLoader(
dataset[split_idx["valid"]],
if args.checkpoint_dir != '': batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
collate_fn=collate_dgl,
)
if args.save_test_dir != "":
test_loader = DataLoader(
dataset[split_idx["test"]],
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
collate_fn=collate_dgl,
)
if args.checkpoint_dir != "":
os.makedirs(args.checkpoint_dir, exist_ok=True) os.makedirs(args.checkpoint_dir, exist_ok=True)
shared_params = { shared_params = {
'num_layers': args.num_layers, "num_layers": args.num_layers,
'emb_dim': args.emb_dim, "emb_dim": args.emb_dim,
'drop_ratio': args.drop_ratio, "drop_ratio": args.drop_ratio,
'graph_pooling': args.graph_pooling "graph_pooling": args.graph_pooling,
} }
if args.gnn == 'gin': if args.gnn == "gin":
model = GNN(gnn_type='gin', virtual_node=False, **shared_params).to(device) model = GNN(gnn_type="gin", virtual_node=False, **shared_params).to(
elif args.gnn == 'gin-virtual': device
model = GNN(gnn_type='gin', virtual_node=True, **shared_params).to(device) )
elif args.gnn == 'gcn': elif args.gnn == "gin-virtual":
model = GNN(gnn_type='gcn', virtual_node=False, **shared_params).to(device) model = GNN(gnn_type="gin", virtual_node=True, **shared_params).to(
elif args.gnn == 'gcn-virtual': device
model = GNN(gnn_type='gcn', virtual_node=True, **shared_params).to(device) )
elif args.gnn == "gcn":
model = GNN(gnn_type="gcn", virtual_node=False, **shared_params).to(
device
)
elif args.gnn == "gcn-virtual":
model = GNN(gnn_type="gcn", virtual_node=True, **shared_params).to(
device
)
else: else:
raise ValueError('Invalid GNN type') raise ValueError("Invalid GNN type")
num_params = sum(p.numel() for p in model.parameters()) num_params = sum(p.numel() for p in model.parameters())
print(f'#Params: {num_params}') print(f"#Params: {num_params}")
optimizer = optim.Adam(model.parameters(), lr=0.001) optimizer = optim.Adam(model.parameters(), lr=0.001)
if args.log_dir != '': if args.log_dir != "":
writer = SummaryWriter(log_dir=args.log_dir) writer = SummaryWriter(log_dir=args.log_dir)
best_valid_mae = 1000 best_valid_mae = 1000
...@@ -201,40 +289,50 @@ def main(): ...@@ -201,40 +289,50 @@ def main():
for epoch in range(1, args.epochs + 1): for epoch in range(1, args.epochs + 1):
print("=====Epoch {}".format(epoch)) print("=====Epoch {}".format(epoch))
print('Training...') print("Training...")
train_mae = train(model, device, train_loader, optimizer) train_mae = train(model, device, train_loader, optimizer)
print('Evaluating...') print("Evaluating...")
valid_mae = eval(model, device, valid_loader, evaluator) valid_mae = eval(model, device, valid_loader, evaluator)
print({'Train': train_mae, 'Validation': valid_mae}) print({"Train": train_mae, "Validation": valid_mae})
if args.log_dir != '': if args.log_dir != "":
writer.add_scalar('valid/mae', valid_mae, epoch) writer.add_scalar("valid/mae", valid_mae, epoch)
writer.add_scalar('train/mae', train_mae, epoch) writer.add_scalar("train/mae", train_mae, epoch)
if valid_mae < best_valid_mae: if valid_mae < best_valid_mae:
best_valid_mae = valid_mae best_valid_mae = valid_mae
if args.checkpoint_dir != '': if args.checkpoint_dir != "":
print('Saving checkpoint...') print("Saving checkpoint...")
checkpoint = {'epoch': epoch, 'model_state_dict': model.state_dict(), checkpoint = {
'optimizer_state_dict': optimizer.state_dict(), "epoch": epoch,
'scheduler_state_dict': scheduler.state_dict(), 'best_val_mae': best_valid_mae, "model_state_dict": model.state_dict(),
'num_params': num_params} "optimizer_state_dict": optimizer.state_dict(),
torch.save(checkpoint, os.path.join(args.checkpoint_dir, 'checkpoint.pt')) "scheduler_state_dict": scheduler.state_dict(),
"best_val_mae": best_valid_mae,
if args.save_test_dir != '': "num_params": num_params,
print('Predicting on test data...') }
torch.save(
checkpoint,
os.path.join(args.checkpoint_dir, "checkpoint.pt"),
)
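The checkpoint saved above bundles model, optimizer, and scheduler state plus bookkeeping fields; resuming from it would look roughly like this (a sketch, assuming the same model, optimizer, and scheduler objects have already been constructed and args parsed as above):

import os
import torch

checkpoint = torch.load(os.path.join(args.checkpoint_dir, "checkpoint.pt"))
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
start_epoch = checkpoint["epoch"] + 1        # continue counting from the saved epoch
best_valid_mae = checkpoint["best_val_mae"]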
if args.save_test_dir != "":
print("Predicting on test data...")
y_pred = test(model, device, test_loader) y_pred = test(model, device, test_loader)
print('Saving test submission file...') print("Saving test submission file...")
evaluator.save_test_submission({'y_pred': y_pred}, args.save_test_dir) evaluator.save_test_submission(
{"y_pred": y_pred}, args.save_test_dir
)
scheduler.step() scheduler.step()
print(f'Best validation MAE so far: {best_valid_mae}') print(f"Best validation MAE so far: {best_valid_mae}")
if args.log_dir != '': if args.log_dir != "":
writer.close() writer.close()
if __name__ == "__main__": if __name__ == "__main__":
main() main()
import argparse import argparse
import dgl
import numpy as np
import os import os
import random import random
import torch
import numpy as np
import torch
from gnn import GNN
from ogb.lsc import PCQM4MDataset, PCQM4MEvaluator from ogb.lsc import PCQM4MDataset, PCQM4MEvaluator
from ogb.utils import smiles2graph from ogb.utils import smiles2graph
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from tqdm import tqdm from tqdm import tqdm
from gnn import GNN import dgl
def collate_dgl(graphs): def collate_dgl(graphs):
batched_graph = dgl.batch(graphs) batched_graph = dgl.batch(graphs)
return batched_graph return batched_graph
def test(model, device, loader): def test(model, device, loader):
model.eval() model.eval()
y_pred = [] y_pred = []
for step, bg in enumerate(tqdm(loader, desc="Iteration")): for step, bg in enumerate(tqdm(loader, desc="Iteration")):
bg = bg.to(device) bg = bg.to(device)
x = bg.ndata.pop('feat') x = bg.ndata.pop("feat")
edge_attr = bg.edata.pop('feat') edge_attr = bg.edata.pop("feat")
with torch.no_grad(): with torch.no_grad():
pred = model(bg, x, edge_attr).view(-1, ) pred = model(bg, x, edge_attr).view(
-1,
)
y_pred.append(pred.detach().cpu()) y_pred.append(pred.detach().cpu())
...@@ -43,53 +47,99 @@ class OnTheFlyPCQMDataset(object): ...@@ -43,53 +47,99 @@ class OnTheFlyPCQMDataset(object):
self.smiles2graph = smiles2graph self.smiles2graph = smiles2graph
def __getitem__(self, idx): def __getitem__(self, idx):
'''Get datapoint with index''' """Get datapoint with index"""
smiles, _ = self.smiles_list[idx] smiles, _ = self.smiles_list[idx]
graph = self.smiles2graph(smiles) graph = self.smiles2graph(smiles)
dgl_graph = dgl.graph((graph['edge_index'][0], graph['edge_index'][1]), dgl_graph = dgl.graph(
num_nodes=graph['num_nodes']) (graph["edge_index"][0], graph["edge_index"][1]),
dgl_graph.edata['feat'] = torch.from_numpy(graph['edge_feat']).to(torch.int64) num_nodes=graph["num_nodes"],
dgl_graph.ndata['feat'] = torch.from_numpy(graph['node_feat']).to(torch.int64) )
dgl_graph.edata["feat"] = torch.from_numpy(graph["edge_feat"]).to(
torch.int64
)
dgl_graph.ndata["feat"] = torch.from_numpy(graph["node_feat"]).to(
torch.int64
)
return dgl_graph return dgl_graph
def __len__(self): def __len__(self):
'''Length of the dataset """Length of the dataset
Returns Returns
------- -------
int int
Length of Dataset Length of Dataset
''' """
return len(self.smiles_list) return len(self.smiles_list)
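A hedged usage sketch of the wrapper above (the SMILES strings and dummy labels are illustrative; smiles2graph comes from ogb.utils as imported at the top of this file):

# Each item is a (smiles, label) pair, matching how test_smiles_dataset is built in main().
toy_dataset = OnTheFlyPCQMDataset([("CCO", 0.0), ("c1ccccc1", 0.0)])
g0 = toy_dataset[0]      # a DGLGraph with int64 "feat" tensors on nodes and edges
print(len(toy_dataset), g0.num_nodes())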
def main(): def main():
# Training settings # Training settings
parser = argparse.ArgumentParser(description='GNN baselines on pcqm4m with DGL') parser = argparse.ArgumentParser(
parser.add_argument('--seed', type=int, default=42, description="GNN baselines on pcqm4m with DGL"
help='random seed to use (default: 42)') )
parser.add_argument('--device', type=int, default=0, parser.add_argument(
help='which gpu to use if any (default: 0)') "--seed", type=int, default=42, help="random seed to use (default: 42)"
parser.add_argument('--gnn', type=str, default='gin-virtual', )
help='GNN to use, which can be from ' parser.add_argument(
'[gin, gin-virtual, gcn, gcn-virtual] (default: gin-virtual)') "--device",
parser.add_argument('--graph_pooling', type=str, default='sum', type=int,
help='graph pooling strategy mean or sum (default: sum)') default=0,
parser.add_argument('--drop_ratio', type=float, default=0, help="which gpu to use if any (default: 0)",
help='dropout ratio (default: 0)') )
parser.add_argument('--num_layers', type=int, default=5, parser.add_argument(
help='number of GNN message passing layers (default: 5)') "--gnn",
parser.add_argument('--emb_dim', type=int, default=600, type=str,
help='dimensionality of hidden units in GNNs (default: 600)') default="gin-virtual",
parser.add_argument('--batch_size', type=int, default=256, help="GNN to use, which can be from "
help='input batch size for training (default: 256)') "[gin, gin-virtual, gcn, gcn-virtual] (default: gin-virtual)",
parser.add_argument('--num_workers', type=int, default=0, )
help='number of workers (default: 0)') parser.add_argument(
parser.add_argument('--checkpoint_dir', type=str, default='', "--graph_pooling",
help='directory to save checkpoint') type=str,
parser.add_argument('--save_test_dir', type=str, default='', default="sum",
help='directory to save test submission file') help="graph pooling strategy mean or sum (default: sum)",
)
parser.add_argument(
"--drop_ratio", type=float, default=0, help="dropout ratio (default: 0)"
)
parser.add_argument(
"--num_layers",
type=int,
default=5,
help="number of GNN message passing layers (default: 5)",
)
parser.add_argument(
"--emb_dim",
type=int,
default=600,
help="dimensionality of hidden units in GNNs (default: 600)",
)
parser.add_argument(
"--batch_size",
type=int,
default=256,
help="input batch size for training (default: 256)",
)
parser.add_argument(
"--num_workers",
type=int,
default=0,
help="number of workers (default: 0)",
)
parser.add_argument(
"--checkpoint_dir",
type=str,
default="",
help="directory to save checkpoint",
)
parser.add_argument(
"--save_test_dir",
type=str,
default="",
help="directory to save test submission file",
)
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
...@@ -106,50 +156,63 @@ def main(): ...@@ -106,50 +156,63 @@ def main():
### automatic data loading and splitting ### automatic data loading and splitting
### Read in the raw SMILES strings ### Read in the raw SMILES strings
smiles_dataset = PCQM4MDataset(root='dataset/', only_smiles=True) smiles_dataset = PCQM4MDataset(root="dataset/", only_smiles=True)
split_idx = smiles_dataset.get_idx_split() split_idx = smiles_dataset.get_idx_split()
test_smiles_dataset = [smiles_dataset[i] for i in split_idx['test']] test_smiles_dataset = [smiles_dataset[i] for i in split_idx["test"]]
onthefly_dataset = OnTheFlyPCQMDataset(test_smiles_dataset) onthefly_dataset = OnTheFlyPCQMDataset(test_smiles_dataset)
test_loader = DataLoader(onthefly_dataset, batch_size=args.batch_size, shuffle=False, test_loader = DataLoader(
num_workers=args.num_workers, collate_fn=collate_dgl) onthefly_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
collate_fn=collate_dgl,
)
### automatic evaluator. ### automatic evaluator.
evaluator = PCQM4MEvaluator() evaluator = PCQM4MEvaluator()
shared_params = { shared_params = {
'num_layers': args.num_layers, "num_layers": args.num_layers,
'emb_dim': args.emb_dim, "emb_dim": args.emb_dim,
'drop_ratio': args.drop_ratio, "drop_ratio": args.drop_ratio,
'graph_pooling': args.graph_pooling "graph_pooling": args.graph_pooling,
} }
if args.gnn == 'gin': if args.gnn == "gin":
model = GNN(gnn_type='gin', virtual_node=False, **shared_params).to(device) model = GNN(gnn_type="gin", virtual_node=False, **shared_params).to(
elif args.gnn == 'gin-virtual': device
model = GNN(gnn_type='gin', virtual_node=True, **shared_params).to(device) )
elif args.gnn == 'gcn': elif args.gnn == "gin-virtual":
model = GNN(gnn_type='gcn', virtual_node=False, **shared_params).to(device) model = GNN(gnn_type="gin", virtual_node=True, **shared_params).to(
elif args.gnn == 'gcn-virtual': device
model = GNN(gnn_type='gcn', virtual_node=True, **shared_params).to(device) )
elif args.gnn == "gcn":
model = GNN(gnn_type="gcn", virtual_node=False, **shared_params).to(
device
)
elif args.gnn == "gcn-virtual":
model = GNN(gnn_type="gcn", virtual_node=True, **shared_params).to(
device
)
else: else:
raise ValueError('Invalid GNN type') raise ValueError("Invalid GNN type")
num_params = sum(p.numel() for p in model.parameters()) num_params = sum(p.numel() for p in model.parameters())
print(f'#Params: {num_params}') print(f"#Params: {num_params}")
checkpoint_path = os.path.join(args.checkpoint_dir, 'checkpoint.pt') checkpoint_path = os.path.join(args.checkpoint_dir, "checkpoint.pt")
if not os.path.exists(checkpoint_path): if not os.path.exists(checkpoint_path):
raise RuntimeError(f'Checkpoint file not found at {checkpoint_path}') raise RuntimeError(f"Checkpoint file not found at {checkpoint_path}")
## reading in checkpoint ## reading in checkpoint
checkpoint = torch.load(checkpoint_path) checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model_state_dict']) model.load_state_dict(checkpoint["model_state_dict"])
print('Predicting on test data...') print("Predicting on test data...")
y_pred = test(model, device, test_loader) y_pred = test(model, device, test_loader)
print('Saving test submission file...') print("Saving test submission file...")
evaluator.save_test_submission({'y_pred': y_pred}, args.save_test_dir) evaluator.save_test_submission({"y_pred": y_pred}, args.save_test_dir)
if __name__ == "__main__": if __name__ == "__main__":
......
"""Graph builder from pandas dataframes""" """Graph builder from pandas dataframes"""
from collections import namedtuple from collections import namedtuple
from pandas.api.types import is_numeric_dtype, is_categorical_dtype, is_categorical
from pandas.api.types import (
is_categorical,
is_categorical_dtype,
is_numeric_dtype,
)
import dgl import dgl
__all__ = ['PandasGraphBuilder'] __all__ = ["PandasGraphBuilder"]
def _series_to_tensor(series): def _series_to_tensor(series):
if is_categorical(series): if is_categorical(series):
return torch.LongTensor(series.cat.codes.values.astype('int64')) return torch.LongTensor(series.cat.codes.values.astype("int64"))
else: # numeric else: # numeric
return torch.FloatTensor(series.values) return torch.FloatTensor(series.values)
class PandasGraphBuilder(object): class PandasGraphBuilder(object):
"""Creates a heterogeneous graph from multiple pandas dataframes. """Creates a heterogeneous graph from multiple pandas dataframes.
...@@ -60,25 +68,36 @@ class PandasGraphBuilder(object): ...@@ -60,25 +68,36 @@ class PandasGraphBuilder(object):
>>> g.num_edges('plays') >>> g.num_edges('plays')
4 4
""" """
def __init__(self): def __init__(self):
self.entity_tables = {} self.entity_tables = {}
self.relation_tables = {} self.relation_tables = {}
self.entity_pk_to_name = {} # mapping from primary key name to entity name self.entity_pk_to_name = (
self.entity_pk = {} # mapping from entity name to primary key {}
self.entity_key_map = {} # mapping from entity names to primary key values ) # mapping from primary key name to entity name
self.entity_pk = {} # mapping from entity name to primary key
self.entity_key_map = (
{}
) # mapping from entity names to primary key values
self.num_nodes_per_type = {} self.num_nodes_per_type = {}
self.edges_per_relation = {} self.edges_per_relation = {}
self.relation_name_to_etype = {} self.relation_name_to_etype = {}
self.relation_src_key = {} # mapping from relation name to source key self.relation_src_key = {} # mapping from relation name to source key
self.relation_dst_key = {} # mapping from relation name to destination key self.relation_dst_key = (
{}
) # mapping from relation name to destination key
def add_entities(self, entity_table, primary_key, name): def add_entities(self, entity_table, primary_key, name):
entities = entity_table[primary_key].astype('category') entities = entity_table[primary_key].astype("category")
if not (entities.value_counts() == 1).all(): if not (entities.value_counts() == 1).all():
raise ValueError('Different entity with the same primary key detected.') raise ValueError(
"Different entity with the same primary key detected."
)
# preserve the category order in the original entity table # preserve the category order in the original entity table
entities = entities.cat.reorder_categories(entity_table[primary_key].values) entities = entities.cat.reorder_categories(
entity_table[primary_key].values
)
self.entity_pk_to_name[primary_key] = name self.entity_pk_to_name[primary_key] = name
self.entity_pk[name] = primary_key self.entity_pk[name] = primary_key
...@@ -86,33 +105,47 @@ class PandasGraphBuilder(object): ...@@ -86,33 +105,47 @@ class PandasGraphBuilder(object):
self.entity_key_map[name] = entities self.entity_key_map[name] = entities
self.entity_tables[name] = entity_table self.entity_tables[name] = entity_table
def add_binary_relations(self, relation_table, source_key, destination_key, name): def add_binary_relations(
src = relation_table[source_key].astype('category') self, relation_table, source_key, destination_key, name
):
src = relation_table[source_key].astype("category")
src = src.cat.set_categories( src = src.cat.set_categories(
self.entity_key_map[self.entity_pk_to_name[source_key]].cat.categories) self.entity_key_map[
dst = relation_table[destination_key].astype('category') self.entity_pk_to_name[source_key]
].cat.categories
)
dst = relation_table[destination_key].astype("category")
dst = dst.cat.set_categories( dst = dst.cat.set_categories(
self.entity_key_map[self.entity_pk_to_name[destination_key]].cat.categories) self.entity_key_map[
self.entity_pk_to_name[destination_key]
].cat.categories
)
if src.isnull().any(): if src.isnull().any():
raise ValueError( raise ValueError(
'Some source entities in relation %s do not exist in entity %s.' % "Some source entities in relation %s do not exist in entity %s."
(name, source_key)) % (name, source_key)
)
if dst.isnull().any(): if dst.isnull().any():
raise ValueError( raise ValueError(
'Some destination entities in relation %s do not exist in entity %s.' % "Some destination entities in relation %s do not exist in entity %s."
(name, destination_key)) % (name, destination_key)
)
srctype = self.entity_pk_to_name[source_key] srctype = self.entity_pk_to_name[source_key]
dsttype = self.entity_pk_to_name[destination_key] dsttype = self.entity_pk_to_name[destination_key]
etype = (srctype, name, dsttype) etype = (srctype, name, dsttype)
self.relation_name_to_etype[name] = etype self.relation_name_to_etype[name] = etype
self.edges_per_relation[etype] = (src.cat.codes.values.astype('int64'), dst.cat.codes.values.astype('int64')) self.edges_per_relation[etype] = (
src.cat.codes.values.astype("int64"),
dst.cat.codes.values.astype("int64"),
)
self.relation_tables[name] = relation_table self.relation_tables[name] = relation_table
self.relation_src_key[name] = source_key self.relation_src_key[name] = source_key
self.relation_dst_key[name] = destination_key self.relation_dst_key[name] = destination_key
def build(self): def build(self):
# Create heterograph # Create heterograph
graph = dgl.heterograph(self.edges_per_relation, self.num_nodes_per_type) graph = dgl.heterograph(
self.edges_per_relation, self.num_nodes_per_type
)
return graph return graph
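A hedged end-to-end sketch of the builder (table, column, and relation names are illustrative, echoing the users/games/plays example in the class docstring):

import pandas as pd

users = pd.DataFrame({"user_id": ["A", "B", "C"]})
games = pd.DataFrame({"game_id": [1, 2]})
plays = pd.DataFrame({"user_id": ["A", "A", "B", "C"], "game_id": [1, 2, 2, 2]})

builder = PandasGraphBuilder()
builder.add_entities(users, "user_id", "user")
builder.add_entities(games, "game_id", "game")
builder.add_binary_relations(plays, "user_id", "game_id", "plays")
builder.add_binary_relations(plays, "game_id", "user_id", "played-by")
g = builder.build()
# g.num_edges("plays") == 4, matching the docstring example above.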
import torch import dask.dataframe as dd
import dgl
import numpy as np import numpy as np
import scipy.sparse as ssp import scipy.sparse as ssp
import torch
import tqdm import tqdm
import dask.dataframe as dd
import dgl
# This is the train-test split method that most recommender system papers running on MovieLens adopt. # This is the train-test split method that most recommender system papers running on MovieLens adopt.
# It essentially follows the intuition of "training on the past and predicting the future". # It essentially follows the intuition of "training on the past and predicting the future".
# One can also change the threshold to make the validation and test sets take larger proportions. # One can also change the threshold to make the validation and test sets take larger proportions.
def train_test_split_by_time(df, timestamp, user): def train_test_split_by_time(df, timestamp, user):
df['train_mask'] = np.ones((len(df),), dtype=np.bool) df["train_mask"] = np.ones((len(df),), dtype=np.bool)
df['val_mask'] = np.zeros((len(df),), dtype=np.bool) df["val_mask"] = np.zeros((len(df),), dtype=np.bool)
df['test_mask'] = np.zeros((len(df),), dtype=np.bool) df["test_mask"] = np.zeros((len(df),), dtype=np.bool)
df = dd.from_pandas(df, npartitions=10) df = dd.from_pandas(df, npartitions=10)
def train_test_split(df): def train_test_split(df):
df = df.sort_values([timestamp]) df = df.sort_values([timestamp])
if df.shape[0] > 1: if df.shape[0] > 1:
...@@ -22,16 +25,25 @@ def train_test_split_by_time(df, timestamp, user): ...@@ -22,16 +25,25 @@ def train_test_split_by_time(df, timestamp, user):
df.iloc[-2, -3] = False df.iloc[-2, -3] = False
df.iloc[-2, -2] = True df.iloc[-2, -2] = True
return df return df
df = df.groupby(user, group_keys=False).apply(train_test_split).compute(scheduler='processes').sort_index()
df = (
df.groupby(user, group_keys=False)
.apply(train_test_split)
.compute(scheduler="processes")
.sort_index()
)
print(df[df[user] == df[user].unique()[0]].sort_values(timestamp)) print(df[df[user] == df[user].unique()[0]].sort_values(timestamp))
return df['train_mask'].to_numpy().nonzero()[0], \ return (
df['val_mask'].to_numpy().nonzero()[0], \ df["train_mask"].to_numpy().nonzero()[0],
df['test_mask'].to_numpy().nonzero()[0] df["val_mask"].to_numpy().nonzero()[0],
df["test_mask"].to_numpy().nonzero()[0],
)
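Concretely, each user's interactions are sorted by timestamp; the most recent one is held out for test and the next most recent for validation (when the user has enough interactions), with the rest kept for training, and the three returned arrays hold the positional indices of each split. A hedged toy example (calling the function requires dask; also note that np.bool above is deprecated in recent NumPy, where plain bool works):

import pandas as pd

ratings = pd.DataFrame({
    "user_id":   [0, 0, 0, 1, 1],
    "movie_id":  [10, 11, 12, 10, 12],
    "timestamp": [1, 2, 3, 5, 4],
})
# Expected split for user 0 (three interactions): (0, 12) -> test, (0, 11) -> validation,
# (0, 10) -> train; users with fewer interactions keep correspondingly more rows in training.
# train_idx, val_idx, test_idx = train_test_split_by_time(ratings, "timestamp", "user_id")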
def build_train_graph(g, train_indices, utype, itype, etype, etype_rev): def build_train_graph(g, train_indices, utype, itype, etype, etype_rev):
train_g = g.edge_subgraph( train_g = g.edge_subgraph(
{etype: train_indices, etype_rev: train_indices}, {etype: train_indices, etype_rev: train_indices}, relabel_nodes=False
relabel_nodes=False) )
# copy features # copy features
for ntype in g.ntypes: for ntype in g.ntypes:
...@@ -39,10 +51,13 @@ def build_train_graph(g, train_indices, utype, itype, etype, etype_rev): ...@@ -39,10 +51,13 @@ def build_train_graph(g, train_indices, utype, itype, etype, etype_rev):
train_g.nodes[ntype].data[col] = data train_g.nodes[ntype].data[col] = data
for etype in g.etypes: for etype in g.etypes:
for col, data in g.edges[etype].data.items(): for col, data in g.edges[etype].data.items():
train_g.edges[etype].data[col] = data[train_g.edges[etype].data[dgl.EID]] train_g.edges[etype].data[col] = data[
train_g.edges[etype].data[dgl.EID]
]
return train_g return train_g
def build_val_test_matrix(g, val_indices, test_indices, utype, itype, etype): def build_val_test_matrix(g, val_indices, test_indices, utype, itype, etype):
n_users = g.num_nodes(utype) n_users = g.num_nodes(utype)
n_items = g.num_nodes(itype) n_items = g.num_nodes(itype)
...@@ -52,11 +67,17 @@ def build_val_test_matrix(g, val_indices, test_indices, utype, itype, etype): ...@@ -52,11 +67,17 @@ def build_val_test_matrix(g, val_indices, test_indices, utype, itype, etype):
val_dst = val_dst.numpy() val_dst = val_dst.numpy()
test_src = test_src.numpy() test_src = test_src.numpy()
test_dst = test_dst.numpy() test_dst = test_dst.numpy()
val_matrix = ssp.coo_matrix((np.ones_like(val_src), (val_src, val_dst)), (n_users, n_items)) val_matrix = ssp.coo_matrix(
test_matrix = ssp.coo_matrix((np.ones_like(test_src), (test_src, test_dst)), (n_users, n_items)) (np.ones_like(val_src), (val_src, val_dst)), (n_users, n_items)
)
test_matrix = ssp.coo_matrix(
(np.ones_like(test_src), (test_src, test_dst)), (n_users, n_items)
)
return val_matrix, test_matrix return val_matrix, test_matrix
def linear_normalize(values): def linear_normalize(values):
return (values - values.min(0, keepdims=True)) / \ return (values - values.min(0, keepdims=True)) / (
(values.max(0, keepdims=True) - values.min(0, keepdims=True)) values.max(0, keepdims=True) - values.min(0, keepdims=True)
)
import argparse
import pickle
import numpy as np import numpy as np
import torch import torch
import pickle
import dgl import dgl
import argparse
def prec(recommendations, ground_truth): def prec(recommendations, ground_truth):
n_users, n_items = ground_truth.shape n_users, n_items = ground_truth.shape
...@@ -13,8 +16,11 @@ def prec(recommendations, ground_truth): ...@@ -13,8 +16,11 @@ def prec(recommendations, ground_truth):
hit = relevance.any(axis=1).mean() hit = relevance.any(axis=1).mean()
return hit return hit
class LatestNNRecommender(object): class LatestNNRecommender(object):
def __init__(self, user_ntype, item_ntype, user_to_item_etype, timestamp, batch_size): def __init__(
self, user_ntype, item_ntype, user_to_item_etype, timestamp, batch_size
):
self.user_ntype = user_ntype self.user_ntype = user_ntype
self.item_ntype = item_ntype self.item_ntype = item_ntype
self.user_to_item_etype = user_to_item_etype self.user_to_item_etype = user_to_item_etype
...@@ -27,19 +33,27 @@ class LatestNNRecommender(object): ...@@ -27,19 +33,27 @@ class LatestNNRecommender(object):
""" """
graph_slice = full_graph.edge_type_subgraph([self.user_to_item_etype]) graph_slice = full_graph.edge_type_subgraph([self.user_to_item_etype])
n_users = full_graph.num_nodes(self.user_ntype) n_users = full_graph.num_nodes(self.user_ntype)
latest_interactions = dgl.sampling.select_topk(graph_slice, 1, self.timestamp, edge_dir='out') latest_interactions = dgl.sampling.select_topk(
user, latest_items = latest_interactions.all_edges(form='uv', order='srcdst') graph_slice, 1, self.timestamp, edge_dir="out"
)
user, latest_items = latest_interactions.all_edges(
form="uv", order="srcdst"
)
# each user should have at least one "latest" interaction # each user should have at least one "latest" interaction
assert torch.equal(user, torch.arange(n_users)) assert torch.equal(user, torch.arange(n_users))
recommended_batches = [] recommended_batches = []
user_batches = torch.arange(n_users).split(self.batch_size) user_batches = torch.arange(n_users).split(self.batch_size)
for user_batch in user_batches: for user_batch in user_batches:
latest_item_batch = latest_items[user_batch].to(device=h_item.device) latest_item_batch = latest_items[user_batch].to(
device=h_item.device
)
dist = h_item[latest_item_batch] @ h_item.t() dist = h_item[latest_item_batch] @ h_item.t()
# exclude items the user has already interacted with # exclude items the user has already interacted with
for i, u in enumerate(user_batch.tolist()): for i, u in enumerate(user_batch.tolist()):
interacted_items = full_graph.successors(u, etype=self.user_to_item_etype) interacted_items = full_graph.successors(
u, etype=self.user_to_item_etype
)
dist[i, interacted_items] = -np.inf dist[i, interacted_items] = -np.inf
recommended_batches.append(dist.topk(K, 1)[1]) recommended_batches.append(dist.topk(K, 1)[1])
...@@ -48,31 +62,33 @@ class LatestNNRecommender(object): ...@@ -48,31 +62,33 @@ class LatestNNRecommender(object):
def evaluate_nn(dataset, h_item, k, batch_size): def evaluate_nn(dataset, h_item, k, batch_size):
g = dataset['train-graph'] g = dataset["train-graph"]
val_matrix = dataset['val-matrix'].tocsr() val_matrix = dataset["val-matrix"].tocsr()
test_matrix = dataset['test-matrix'].tocsr() test_matrix = dataset["test-matrix"].tocsr()
item_texts = dataset['item-texts'] item_texts = dataset["item-texts"]
user_ntype = dataset['user-type'] user_ntype = dataset["user-type"]
item_ntype = dataset['item-type'] item_ntype = dataset["item-type"]
user_to_item_etype = dataset['user-to-item-type'] user_to_item_etype = dataset["user-to-item-type"]
timestamp = dataset['timestamp-edge-column'] timestamp = dataset["timestamp-edge-column"]
rec_engine = LatestNNRecommender( rec_engine = LatestNNRecommender(
user_ntype, item_ntype, user_to_item_etype, timestamp, batch_size) user_ntype, item_ntype, user_to_item_etype, timestamp, batch_size
)
recommendations = rec_engine.recommend(g, k, None, h_item).cpu().numpy() recommendations = rec_engine.recommend(g, k, None, h_item).cpu().numpy()
return prec(recommendations, val_matrix) return prec(recommendations, val_matrix)
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('dataset_path', type=str) parser.add_argument("dataset_path", type=str)
parser.add_argument('item_embedding_path', type=str) parser.add_argument("item_embedding_path", type=str)
parser.add_argument('-k', type=int, default=10) parser.add_argument("-k", type=int, default=10)
parser.add_argument('--batch-size', type=int, default=32) parser.add_argument("--batch-size", type=int, default=32)
args = parser.parse_args() args = parser.parse_args()
with open(args.dataset_path, 'rb') as f: with open(args.dataset_path, "rb") as f:
dataset = pickle.load(f) dataset = pickle.load(f)
with open(args.item_embedding_path, 'rb') as f: with open(args.item_embedding_path, "rb") as f:
emb = torch.FloatTensor(pickle.load(f)) emb = torch.FloatTensor(pickle.load(f))
print(evaluate_nn(dataset, emb, args.k, args.batch_size)) print(evaluate_nn(dataset, emb, args.k, args.batch_size))
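The value printed above is a hit-rate@K: for every user, prec checks whether any of the K recommended items appears in that user's held-out interactions, then averages across users (here against the validation matrix). A hedged standalone sketch of that metric:

import numpy as np

# Illustrative only: recommendations is an (n_users, K) array of item indices,
# ground_truth a scipy CSR user-item matrix of held-out interactions.
def hit_rate_at_k(recommendations, ground_truth):
    relevance = np.stack([
        np.asarray(ground_truth[u, recommendations[u]].todense()).ravel() > 0
        for u in range(recommendations.shape[0])
    ])
    return relevance.any(axis=1).mean()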
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import dgl import dgl
import dgl.nn.pytorch as dglnn
import dgl.function as fn import dgl.function as fn
import dgl.nn.pytorch as dglnn
def disable_grad(module): def disable_grad(module):
for param in module.parameters(): for param in module.parameters():
param.requires_grad = False param.requires_grad = False
def _init_input_modules(g, ntype, textset, hidden_dims): def _init_input_modules(g, ntype, textset, hidden_dims):
# We initialize the linear projections of each input feature ``x`` as # We initialize the linear projections of each input feature ``x`` as
# follows: # follows:
...@@ -30,44 +33,50 @@ def _init_input_modules(g, ntype, textset, hidden_dims): ...@@ -30,44 +33,50 @@ def _init_input_modules(g, ntype, textset, hidden_dims):
module_dict[column] = m module_dict[column] = m
elif data.dtype == torch.int64: elif data.dtype == torch.int64:
assert data.ndim == 1 assert data.ndim == 1
m = nn.Embedding( m = nn.Embedding(data.max() + 2, hidden_dims, padding_idx=-1)
data.max() + 2, hidden_dims, padding_idx=-1)
nn.init.xavier_uniform_(m.weight) nn.init.xavier_uniform_(m.weight)
module_dict[column] = m module_dict[column] = m
if textset is not None: if textset is not None:
for column, field in textset.items(): for column, field in textset.items():
textlist, vocab, pad_var, batch_first = field textlist, vocab, pad_var, batch_first = field
module_dict[column] = BagOfWords(vocab, hidden_dims) module_dict[column] = BagOfWords(vocab, hidden_dims)
return module_dict return module_dict
class BagOfWords(nn.Module): class BagOfWords(nn.Module):
def __init__(self, vocab, hidden_dims): def __init__(self, vocab, hidden_dims):
super().__init__() super().__init__()
self.emb = nn.Embedding( self.emb = nn.Embedding(
len(vocab.get_itos()), hidden_dims, len(vocab.get_itos()),
padding_idx=vocab.get_stoi()['<pad>']) hidden_dims,
padding_idx=vocab.get_stoi()["<pad>"],
)
nn.init.xavier_uniform_(self.emb.weight) nn.init.xavier_uniform_(self.emb.weight)
def forward(self, x, length): def forward(self, x, length):
return self.emb(x).sum(1) / length.unsqueeze(1).float() return self.emb(x).sum(1) / length.unsqueeze(1).float()
class LinearProjector(nn.Module): class LinearProjector(nn.Module):
""" """
Projects each input feature of the graph linearly and sums them up Projects each input feature of the graph linearly and sums them up
""" """
def __init__(self, full_graph, ntype, textset, hidden_dims): def __init__(self, full_graph, ntype, textset, hidden_dims):
super().__init__() super().__init__()
self.ntype = ntype self.ntype = ntype
self.inputs = _init_input_modules(full_graph, ntype, textset, hidden_dims) self.inputs = _init_input_modules(
full_graph, ntype, textset, hidden_dims
)
def forward(self, ndata): def forward(self, ndata):
projections = [] projections = []
for feature, data in ndata.items(): for feature, data in ndata.items():
if feature == dgl.NID or feature.endswith('__len'): if feature == dgl.NID or feature.endswith("__len"):
# This is an additional feature indicating the length of the ``feature`` # This is an additional feature indicating the length of the ``feature``
# column; we shouldn't process this. # column; we shouldn't process this.
continue continue
...@@ -75,7 +84,7 @@ class LinearProjector(nn.Module): ...@@ -75,7 +84,7 @@ class LinearProjector(nn.Module):
module = self.inputs[feature] module = self.inputs[feature]
if isinstance(module, BagOfWords): if isinstance(module, BagOfWords):
# Textual feature; find the length and pass it to the textual module. # Textual feature; find the length and pass it to the textual module.
length = ndata[feature + '__len'] length = ndata[feature + "__len"]
result = module(data, length) result = module(data, length)
else: else:
result = module(data) result = module(data)
...@@ -83,6 +92,7 @@ class LinearProjector(nn.Module): ...@@ -83,6 +92,7 @@ class LinearProjector(nn.Module):
return torch.stack(projections, 1).sum(1) return torch.stack(projections, 1).sum(1)
class WeightedSAGEConv(nn.Module): class WeightedSAGEConv(nn.Module):
def __init__(self, input_dims, hidden_dims, output_dims, act=F.relu): def __init__(self, input_dims, hidden_dims, output_dims, act=F.relu):
super().__init__() super().__init__()
...@@ -94,7 +104,7 @@ class WeightedSAGEConv(nn.Module): ...@@ -94,7 +104,7 @@ class WeightedSAGEConv(nn.Module):
self.dropout = nn.Dropout(0.5) self.dropout = nn.Dropout(0.5)
def reset_parameters(self): def reset_parameters(self):
gain = nn.init.calculate_gain('relu') gain = nn.init.calculate_gain("relu")
nn.init.xavier_uniform_(self.Q.weight, gain=gain) nn.init.xavier_uniform_(self.Q.weight, gain=gain)
nn.init.xavier_uniform_(self.W.weight, gain=gain) nn.init.xavier_uniform_(self.W.weight, gain=gain)
nn.init.constant_(self.Q.bias, 0) nn.init.constant_(self.Q.bias, 0)
...@@ -108,18 +118,21 @@ class WeightedSAGEConv(nn.Module): ...@@ -108,18 +118,21 @@ class WeightedSAGEConv(nn.Module):
""" """
h_src, h_dst = h h_src, h_dst = h
with g.local_scope(): with g.local_scope():
g.srcdata['n'] = self.act(self.Q(self.dropout(h_src))) g.srcdata["n"] = self.act(self.Q(self.dropout(h_src)))
g.edata['w'] = weights.float() g.edata["w"] = weights.float()
g.update_all(fn.u_mul_e('n', 'w', 'm'), fn.sum('m', 'n')) g.update_all(fn.u_mul_e("n", "w", "m"), fn.sum("m", "n"))
g.update_all(fn.copy_e('w', 'm'), fn.sum('m', 'ws')) g.update_all(fn.copy_e("w", "m"), fn.sum("m", "ws"))
n = g.dstdata['n'] n = g.dstdata["n"]
ws = g.dstdata['ws'].unsqueeze(1).clamp(min=1) ws = g.dstdata["ws"].unsqueeze(1).clamp(min=1)
z = self.act(self.W(self.dropout(torch.cat([n / ws, h_dst], 1)))) z = self.act(self.W(self.dropout(torch.cat([n / ws, h_dst], 1))))
z_norm = z.norm(2, 1, keepdim=True) z_norm = z.norm(2, 1, keepdim=True)
z_norm = torch.where(z_norm == 0, torch.tensor(1.).to(z_norm), z_norm) z_norm = torch.where(
z_norm == 0, torch.tensor(1.0).to(z_norm), z_norm
)
z = z / z_norm z = z / z_norm
return z return z
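# Dense sketch of the same aggregation (dropout omitted, toy sizes): each
# destination node takes a weighted mean of its transformed neighbours,
# concatenates its own feature, applies W, and L2-normalizes the result.
import torch
import torch.nn.functional as F

h_src, h_dst = torch.randn(4, 8), torch.randn(2, 8)   # 4 src nodes, 2 dst nodes
A = torch.tensor([[1.0, 0.0, 2.0, 0.0],               # edge weights, dst x src
                  [0.0, 3.0, 0.0, 1.0]])
Q, W = torch.nn.Linear(8, 8), torch.nn.Linear(16, 8)
n = A @ F.relu(Q(h_src))                              # sum_u w_uv * act(Q h_u)
ws = A.sum(1, keepdim=True).clamp(min=1)              # sum of incoming weights
z = F.relu(W(torch.cat([n / ws, h_dst], 1)))
z = z / z.norm(2, 1, keepdim=True).clamp(min=1e-12)   # guard against zero norm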
class SAGENet(nn.Module): class SAGENet(nn.Module):
def __init__(self, hidden_dims, n_layers): def __init__(self, hidden_dims, n_layers):
""" """
...@@ -133,14 +146,17 @@ class SAGENet(nn.Module): ...@@ -133,14 +146,17 @@ class SAGENet(nn.Module):
self.convs = nn.ModuleList() self.convs = nn.ModuleList()
for _ in range(n_layers): for _ in range(n_layers):
self.convs.append(WeightedSAGEConv(hidden_dims, hidden_dims, hidden_dims)) self.convs.append(
WeightedSAGEConv(hidden_dims, hidden_dims, hidden_dims)
)
def forward(self, blocks, h): def forward(self, blocks, h):
for layer, block in zip(self.convs, blocks): for layer, block in zip(self.convs, blocks):
h_dst = h[:block.num_nodes('DST/' + block.ntypes[0])] h_dst = h[: block.num_nodes("DST/" + block.ntypes[0])]
h = layer(block, (h, h_dst), block.edata['weights']) h = layer(block, (h, h_dst), block.edata["weights"])
return h return h
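# Sketch of the slicing convention used above: in a DGL message-flow graph the
# destination nodes are listed first among the source nodes, so the first
# num_dst_nodes rows of h are exactly the destination features the layer needs.
import torch

h = torch.randn(10, 16)                               # all src-node features of a block
num_dst = 3                                           # e.g. block.num_dst_nodes() in recent DGL
h_dst = h[:num_dst]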
class ItemToItemScorer(nn.Module): class ItemToItemScorer(nn.Module):
def __init__(self, full_graph, ntype): def __init__(self, full_graph, ntype):
super().__init__() super().__init__()
...@@ -151,7 +167,7 @@ class ItemToItemScorer(nn.Module): ...@@ -151,7 +167,7 @@ class ItemToItemScorer(nn.Module):
def _add_bias(self, edges): def _add_bias(self, edges):
bias_src = self.bias[edges.src[dgl.NID]] bias_src = self.bias[edges.src[dgl.NID]]
bias_dst = self.bias[edges.dst[dgl.NID]] bias_dst = self.bias[edges.dst[dgl.NID]]
return {'s': edges.data['s'] + bias_src + bias_dst} return {"s": edges.data["s"] + bias_src + bias_dst}
def forward(self, item_item_graph, h): def forward(self, item_item_graph, h):
""" """
...@@ -159,8 +175,8 @@ class ItemToItemScorer(nn.Module): ...@@ -159,8 +175,8 @@ class ItemToItemScorer(nn.Module):
h : hidden state of every node h : hidden state of every node
""" """
with item_item_graph.local_scope(): with item_item_graph.local_scope():
item_item_graph.ndata['h'] = h item_item_graph.ndata["h"] = h
item_item_graph.apply_edges(fn.u_dot_v('h', 'h', 's')) item_item_graph.apply_edges(fn.u_dot_v("h", "h", "s"))
item_item_graph.apply_edges(self._add_bias) item_item_graph.apply_edges(self._add_bias)
pair_score = item_item_graph.edata['s'] pair_score = item_item_graph.edata["s"]
return pair_score return pair_score
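# Dense sketch of the scorer (made-up sizes): for each candidate edge (u, v)
# the score is the dot product of the two item representations plus the
# learnable per-item biases of both endpoints.
import torch

h = torch.randn(6, 16)                                # hidden states of 6 items
bias = torch.zeros(6)                                 # per-item bias, as in self.bias
u = torch.tensor([0, 2])                              # source items of two edges
v = torch.tensor([1, 5])                              # destination items
pair_score = (h[u] * h[v]).sum(1) + bias[u] + bias[v]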
import pickle
import argparse import argparse
import os
import pickle
import evaluation
import layers
import numpy as np import numpy as np
import sampler as sampler_module
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.utils.data import DataLoader
import torchtext import torchtext
import dgl
import os
import tqdm import tqdm
from torch.utils.data import DataLoader
import layers
import sampler as sampler_module
import evaluation
from torchtext.data.utils import get_tokenizer from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator from torchtext.vocab import build_vocab_from_iterator
import dgl
class PinSAGEModel(nn.Module): class PinSAGEModel(nn.Module):
def __init__(self, full_graph, ntype, textsets, hidden_dims, n_layers): def __init__(self, full_graph, ntype, textsets, hidden_dims, n_layers):
super().__init__() super().__init__()
self.proj = layers.LinearProjector(full_graph, ntype, textsets, hidden_dims) self.proj = layers.LinearProjector(
full_graph, ntype, textsets, hidden_dims
)
self.sage = layers.SAGENet(hidden_dims, n_layers) self.sage = layers.SAGENet(hidden_dims, n_layers)
self.scorer = layers.ItemToItemScorer(full_graph, ntype) self.scorer = layers.ItemToItemScorer(full_graph, ntype)
...@@ -36,19 +40,22 @@ class PinSAGEModel(nn.Module): ...@@ -36,19 +40,22 @@ class PinSAGEModel(nn.Module):
# add to the item embedding itself # add to the item embedding itself
h_item = h_item + item_emb(blocks[0].srcdata[dgl.NID].cpu()).to(h_item) h_item = h_item + item_emb(blocks[0].srcdata[dgl.NID].cpu()).to(h_item)
h_item_dst = h_item_dst + item_emb(blocks[-1].dstdata[dgl.NID].cpu()).to(h_item_dst) h_item_dst = h_item_dst + item_emb(
blocks[-1].dstdata[dgl.NID].cpu()
).to(h_item_dst)
return h_item_dst + self.sage(blocks, h_item) return h_item_dst + self.sage(blocks, h_item)
def train(dataset, args): def train(dataset, args):
g = dataset['train-graph'] g = dataset["train-graph"]
val_matrix = dataset['val-matrix'].tocsr() val_matrix = dataset["val-matrix"].tocsr()
test_matrix = dataset['test-matrix'].tocsr() test_matrix = dataset["test-matrix"].tocsr()
item_texts = dataset['item-texts'] item_texts = dataset["item-texts"]
user_ntype = dataset['user-type'] user_ntype = dataset["user-type"]
item_ntype = dataset['item-type'] item_ntype = dataset["item-type"]
user_to_item_etype = dataset['user-to-item-type'] user_to_item_etype = dataset["user-to-item-type"]
timestamp = dataset['timestamp-edge-column'] timestamp = dataset["timestamp-edge-column"]
device = torch.device(args.device) device = torch.device(args.device)
...@@ -64,31 +71,53 @@ def train(dataset, args): ...@@ -64,31 +71,53 @@ def train(dataset, args):
l = tokenizer(item_texts[key][i].lower()) l = tokenizer(item_texts[key][i].lower())
textlist.append(l) textlist.append(l)
for key, field in item_texts.items(): for key, field in item_texts.items():
vocab2 = build_vocab_from_iterator(textlist, specials=["<unk>","<pad>"]) vocab2 = build_vocab_from_iterator(
textset[key] = (textlist, vocab2, vocab2.get_stoi()['<pad>'], batch_first) textlist, specials=["<unk>", "<pad>"]
)
textset[key] = (
textlist,
vocab2,
vocab2.get_stoi()["<pad>"],
batch_first,
)
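# Sketch of the torchtext calls used above, with made-up sentences; the
# "basic_english" tokenizer is an assumption here, since the real tokenizer
# is created outside this hunk.
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

tokenizer = get_tokenizer("basic_english")
textlist = [tokenizer(s.lower()) for s in ["A Good Movie", "a bad movie"]]
vocab2 = build_vocab_from_iterator(textlist, specials=["<unk>", "<pad>"])
pad_id = vocab2.get_stoi()["<pad>"]                   # later used as padding_idx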
# Sampler # Sampler
batch_sampler = sampler_module.ItemToItemBatchSampler( batch_sampler = sampler_module.ItemToItemBatchSampler(
g, user_ntype, item_ntype, args.batch_size) g, user_ntype, item_ntype, args.batch_size
)
neighbor_sampler = sampler_module.NeighborSampler( neighbor_sampler = sampler_module.NeighborSampler(
g, user_ntype, item_ntype, args.random_walk_length, g,
args.random_walk_restart_prob, args.num_random_walks, args.num_neighbors, user_ntype,
args.num_layers) item_ntype,
collator = sampler_module.PinSAGECollator(neighbor_sampler, g, item_ntype, textset) args.random_walk_length,
args.random_walk_restart_prob,
args.num_random_walks,
args.num_neighbors,
args.num_layers,
)
collator = sampler_module.PinSAGECollator(
neighbor_sampler, g, item_ntype, textset
)
dataloader = DataLoader( dataloader = DataLoader(
batch_sampler, batch_sampler,
collate_fn=collator.collate_train, collate_fn=collator.collate_train,
num_workers=args.num_workers) num_workers=args.num_workers,
)
dataloader_test = DataLoader( dataloader_test = DataLoader(
torch.arange(g.num_nodes(item_ntype)), torch.arange(g.num_nodes(item_ntype)),
batch_size=args.batch_size, batch_size=args.batch_size,
collate_fn=collator.collate_test, collate_fn=collator.collate_test,
num_workers=args.num_workers) num_workers=args.num_workers,
)
dataloader_it = iter(dataloader) dataloader_it = iter(dataloader)
# Model # Model
model = PinSAGEModel(g, item_ntype, textset, args.hidden_dims, args.num_layers).to(device) model = PinSAGEModel(
item_emb = nn.Embedding(g.num_nodes(item_ntype), args.hidden_dims, sparse=True) g, item_ntype, textset, args.hidden_dims, args.num_layers
).to(device)
item_emb = nn.Embedding(
g.num_nodes(item_ntype), args.hidden_dims, sparse=True
)
# Optimizer # Optimizer
opt = torch.optim.Adam(model.parameters(), lr=args.lr) opt = torch.optim.Adam(model.parameters(), lr=args.lr)
opt_emb = torch.optim.SparseAdam(item_emb.parameters(), lr=args.lr) opt_emb = torch.optim.SparseAdam(item_emb.parameters(), lr=args.lr)
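# Sketch of why two optimizers are used: nn.Embedding(..., sparse=True)
# produces sparse gradients, which torch.optim.Adam rejects, so the item
# embedding table gets its own SparseAdam while the dense model parameters
# stay under Adam.
import torch
import torch.nn as nn

table = nn.Embedding(100, 16, sparse=True)            # sparse gradients
dense = nn.Linear(16, 16)
opt = torch.optim.Adam(dense.parameters(), lr=3e-5)
opt_emb = torch.optim.SparseAdam(table.parameters(), lr=3e-5)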
...@@ -114,7 +143,9 @@ def train(dataset, args): ...@@ -114,7 +143,9 @@ def train(dataset, args):
# Evaluate # Evaluate
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
item_batches = torch.arange(g.num_nodes(item_ntype)).split(args.batch_size) item_batches = torch.arange(g.num_nodes(item_ntype)).split(
args.batch_size
)
h_item_batches = [] h_item_batches = []
for blocks in tqdm.tqdm(dataloader_test): for blocks in tqdm.tqdm(dataloader_test):
for i in range(len(blocks)): for i in range(len(blocks)):
...@@ -123,32 +154,37 @@ def train(dataset, args): ...@@ -123,32 +154,37 @@ def train(dataset, args):
h_item_batches.append(model.get_repr(blocks, item_emb)) h_item_batches.append(model.get_repr(blocks, item_emb))
h_item = torch.cat(h_item_batches, 0) h_item = torch.cat(h_item_batches, 0)
print(evaluation.evaluate_nn(dataset, h_item, args.k, args.batch_size)) print(
evaluation.evaluate_nn(dataset, h_item, args.k, args.batch_size)
)
if __name__ == '__main__': if __name__ == "__main__":
# Arguments # Arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('dataset_path', type=str) parser.add_argument("dataset_path", type=str)
parser.add_argument('--random-walk-length', type=int, default=2) parser.add_argument("--random-walk-length", type=int, default=2)
parser.add_argument('--random-walk-restart-prob', type=float, default=0.5) parser.add_argument("--random-walk-restart-prob", type=float, default=0.5)
parser.add_argument('--num-random-walks', type=int, default=10) parser.add_argument("--num-random-walks", type=int, default=10)
parser.add_argument('--num-neighbors', type=int, default=3) parser.add_argument("--num-neighbors", type=int, default=3)
parser.add_argument('--num-layers', type=int, default=2) parser.add_argument("--num-layers", type=int, default=2)
parser.add_argument('--hidden-dims', type=int, default=16) parser.add_argument("--hidden-dims", type=int, default=16)
parser.add_argument('--batch-size', type=int, default=32) parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument('--device', type=str, default='cpu') # can also be "cuda:0" parser.add_argument(
parser.add_argument('--num-epochs', type=int, default=1) "--device", type=str, default="cpu"
parser.add_argument('--batches-per-epoch', type=int, default=20000) ) # can also be "cuda:0"
parser.add_argument('--num-workers', type=int, default=0) parser.add_argument("--num-epochs", type=int, default=1)
parser.add_argument('--lr', type=float, default=3e-5) parser.add_argument("--batches-per-epoch", type=int, default=20000)
parser.add_argument('-k', type=int, default=10) parser.add_argument("--num-workers", type=int, default=0)
parser.add_argument("--lr", type=float, default=3e-5)
parser.add_argument("-k", type=int, default=10)
args = parser.parse_args() args = parser.parse_args()
# Load dataset # Load dataset
data_info_path = os.path.join(args.dataset_path, 'data.pkl') data_info_path = os.path.join(args.dataset_path, "data.pkl")
with open(data_info_path, 'rb') as f: with open(data_info_path, "rb") as f:
dataset = pickle.load(f) dataset = pickle.load(f)
train_g_path = os.path.join(args.dataset_path, 'train_g.bin') train_g_path = os.path.join(args.dataset_path, "train_g.bin")
g_list, _ = dgl.load_graphs(train_g_path) g_list, _ = dgl.load_graphs(train_g_path)
dataset['train-graph'] = g_list[0] dataset["train-graph"] = g_list[0]
train(dataset, args) train(dataset, args)