Unverified Commit 0b9df9d7 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4652)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent f19f05ce
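
The changes below are the mechanical output of the Black formatter (quote normalization, trailing commas, and line wrapping). As a minimal sketch of how such a pass could be reproduced with Black's Python API — assuming Black is installed, and assuming a 79-character line length inferred from where lines wrap in this diff rather than stated in the commit:

import black

# Example input in the pre-commit style: single quotes, one over-long line
# (keys taken from the MWE_GCN_proteins config in this diff).
src = (
    "MWE_GCN_proteins = {'num_ew_channels': 8, 'num_epochs': 2000, "
    "'in_feats': 1, 'aggr_mode': 'sum', 'ewnorm': 'both'}\n"
)

# format_str applies the rewrites seen throughout this diff: double quotes,
# and for lines over the configured length, exploding the literal across
# lines with a trailing comma. line_length=79 is an assumption.
formatted = black.format_str(src, mode=black.Mode(line_length=79))
print(formatted)

Running the sketch prints the dict in the same multi-line, double-quoted form that appears in the reformatted files below.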
......@@ -2,39 +2,39 @@
import torch
MWE_GCN_proteins = {
'num_ew_channels': 8,
'num_epochs': 2000,
'in_feats': 1,
'hidden_feats': 10,
'out_feats': 112,
'n_layers': 3,
'lr': 2e-2,
'weight_decay': 0,
'patience': 1000,
'dropout': 0.2,
'aggr_mode': 'sum', ## 'sum' or 'concat' for the aggregation across channels
'ewnorm': 'both'
}
"num_ew_channels": 8,
"num_epochs": 2000,
"in_feats": 1,
"hidden_feats": 10,
"out_feats": 112,
"n_layers": 3,
"lr": 2e-2,
"weight_decay": 0,
"patience": 1000,
"dropout": 0.2,
"aggr_mode": "sum", ## 'sum' or 'concat' for the aggregation across channels
"ewnorm": "both",
}
MWE_DGCN_proteins = {
'num_ew_channels': 8,
'num_epochs': 2000,
'in_feats': 1,
'hidden_feats': 10,
'out_feats': 112,
'n_layers': 2,
'lr': 1e-2,
'weight_decay': 0,
'patience': 300,
'dropout': 0.5,
'aggr_mode': 'sum',
'residual': True,
'ewnorm': 'none'
}
"num_ew_channels": 8,
"num_epochs": 2000,
"in_feats": 1,
"hidden_feats": 10,
"out_feats": 112,
"n_layers": 2,
"lr": 1e-2,
"weight_decay": 0,
"patience": 300,
"dropout": 0.5,
"aggr_mode": "sum",
"residual": True,
"ewnorm": "none",
}
def get_exp_configure(args):
if (args['model'] == 'MWE-GCN'):
if args["model"] == "MWE-GCN":
return MWE_GCN_proteins
elif (args['model'] == 'MWE-DGCN'):
elif args["model"] == "MWE-DGCN":
return MWE_DGCN_proteins
......@@ -7,20 +7,23 @@ import random
import sys
import time
import dgl
import dgl.function as fn
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from dgl.dataloading import MultiLayerFullNeighborSampler, MultiLayerNeighborSampler
from dgl.dataloading import DataLoader
from matplotlib.ticker import AutoMinorLocator, MultipleLocator
from models import GAT
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from torch import nn
from models import GAT
import dgl
import dgl.function as fn
from dgl.dataloading import (
DataLoader,
MultiLayerFullNeighborSampler,
MultiLayerNeighborSampler,
)
device = None
dataset = "ogbn-proteins"
......@@ -43,7 +46,11 @@ def load_data(dataset):
evaluator = Evaluator(name=dataset)
splitted_idx = data.get_idx_split()
train_idx, val_idx, test_idx = splitted_idx["train"], splitted_idx["valid"], splitted_idx["test"]
train_idx, val_idx, test_idx = (
splitted_idx["train"],
splitted_idx["valid"],
splitted_idx["test"],
)
graph, labels = data[0]
graph.ndata["labels"] = labels
......@@ -54,11 +61,15 @@ def preprocess(graph, labels, train_idx):
global n_node_feats
# The sum of the weights of adjacent edges is used as node features.
graph.update_all(fn.copy_e("feat", "feat_copy"), fn.sum("feat_copy", "feat"))
graph.update_all(
fn.copy_e("feat", "feat_copy"), fn.sum("feat_copy", "feat")
)
n_node_feats = graph.ndata["feat"].shape[-1]
# Only the labels in the training set are used as features, while others are filled with zeros.
graph.ndata["train_labels_onehot"] = torch.zeros(graph.number_of_nodes(), n_classes)
graph.ndata["train_labels_onehot"] = torch.zeros(
graph.number_of_nodes(), n_classes
)
graph.ndata["train_labels_onehot"][train_idx, labels[train_idx, 0]] = 1
graph.ndata["deg"] = graph.out_degrees().float().clamp(min=1)
......@@ -99,7 +110,16 @@ def add_labels(graph, idx):
graph.srcdata["feat"] = torch.cat([feat, train_labels_onehot], dim=-1)
def train(args, model, dataloader, _labels, _train_idx, criterion, optimizer, _evaluator):
def train(
args,
model,
dataloader,
_labels,
_train_idx,
criterion,
optimizer,
_evaluator,
):
model.train()
loss_sum, total = 0, 0
......@@ -109,7 +129,9 @@ def train(args, model, dataloader, _labels, _train_idx, criterion, optimizer, _e
new_train_idx = torch.arange(len(output_nodes), device=device)
if args.use_labels:
train_labels_idx = torch.arange(len(output_nodes), len(input_nodes), device=device)
train_labels_idx = torch.arange(
len(output_nodes), len(input_nodes), device=device
)
train_pred_idx = new_train_idx
add_labels(subgraphs[0], train_labels_idx)
......@@ -117,7 +139,10 @@ def train(args, model, dataloader, _labels, _train_idx, criterion, optimizer, _e
train_pred_idx = new_train_idx
pred = model(subgraphs)
loss = criterion(pred[train_pred_idx], subgraphs[-1].dstdata["labels"][train_pred_idx].float())
loss = criterion(
pred[train_pred_idx],
subgraphs[-1].dstdata["labels"][train_pred_idx].float(),
)
optimizer.zero_grad()
loss.backward()
optimizer.step()
......@@ -132,7 +157,17 @@ def train(args, model, dataloader, _labels, _train_idx, criterion, optimizer, _e
@torch.no_grad()
def evaluate(args, model, dataloader, labels, train_idx, val_idx, test_idx, criterion, evaluator):
def evaluate(
args,
model,
dataloader,
labels,
train_idx,
val_idx,
test_idx,
criterion,
evaluator,
):
model.eval()
preds = torch.zeros(labels.shape).to(device)
......@@ -170,37 +205,49 @@ def evaluate(args, model, dataloader, labels, train_idx, val_idx, test_idx, crit
)
def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running):
evaluator_wrapper = lambda pred, labels: evaluator.eval({"y_pred": pred, "y_true": labels})["rocauc"]
def run(
args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running
):
evaluator_wrapper = lambda pred, labels: evaluator.eval(
{"y_pred": pred, "y_true": labels}
)["rocauc"]
train_batch_size = (len(train_idx) + 9) // 10
# batch_size = len(train_idx)
train_sampler = MultiLayerNeighborSampler([32 for _ in range(args.n_layers)])
train_sampler = MultiLayerNeighborSampler(
[32 for _ in range(args.n_layers)]
)
# sampler = MultiLayerFullNeighborSampler(args.n_layers)
train_dataloader = DataLoader(
graph.cpu(),
train_idx.cpu(),
train_sampler,
batch_size=train_batch_size,
num_workers=10,
num_workers=10,
)
eval_sampler = MultiLayerNeighborSampler([100 for _ in range(args.n_layers)])
eval_sampler = MultiLayerNeighborSampler(
[100 for _ in range(args.n_layers)]
)
# sampler = MultiLayerFullNeighborSampler(args.n_layers)
eval_dataloader = DataLoader(
eval_dataloader = DataLoader(
graph.cpu(),
torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx.cpu()]),
eval_sampler,
batch_size=65536,
num_workers=10,
num_workers=10,
)
criterion = nn.BCEWithLogitsLoss()
model = gen_model(args).to(device)
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.75, patience=50, verbose=True)
optimizer = optim.AdamW(
model.parameters(), lr=args.lr, weight_decay=args.wd
)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="max", factor=0.75, patience=50, verbose=True
)
total_time = 0
val_score, best_val_score, final_test_score = 0, 0, 0
......@@ -212,14 +259,43 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
for epoch in range(1, args.n_epochs + 1):
tic = time.time()
loss = train(args, model, train_dataloader, labels, train_idx, criterion, optimizer, evaluator_wrapper)
loss = train(
args,
model,
train_dataloader,
labels,
train_idx,
criterion,
optimizer,
evaluator_wrapper,
)
toc = time.time()
total_time += toc - tic
if epoch == args.n_epochs or epoch % args.eval_every == 0 or epoch % args.log_every == 0:
train_score, val_score, test_score, train_loss, val_loss, test_loss, pred = evaluate(
args, model, eval_dataloader, labels, train_idx, val_idx, test_idx, criterion, evaluator_wrapper
if (
epoch == args.n_epochs
or epoch % args.eval_every == 0
or epoch % args.log_every == 0
):
(
train_score,
val_score,
test_score,
train_loss,
val_loss,
test_loss,
pred,
) = evaluate(
args,
model,
eval_dataloader,
labels,
train_idx,
val_idx,
test_idx,
criterion,
evaluator_wrapper,
)
if val_score > best_val_score:
......@@ -238,15 +314,33 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
)
for l, e in zip(
[train_scores, val_scores, test_scores, losses, train_losses, val_losses, test_losses],
[train_score, val_score, test_score, loss, train_loss, val_loss, test_loss],
[
train_scores,
val_scores,
test_scores,
losses,
train_losses,
val_losses,
test_losses,
],
[
train_score,
val_score,
test_score,
loss,
train_loss,
val_loss,
test_loss,
],
):
l.append(e)
lr_scheduler.step(val_score)
print("*" * 50)
print(f"Best val score: {best_val_score}, Final test score: {final_test_score}")
print(
f"Best val score: {best_val_score}, Final test score: {final_test_score}"
)
print("*" * 50)
if args.plot:
......@@ -255,8 +349,16 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
ax.set_xticks(np.arange(0, args.n_epochs, 100))
ax.set_yticks(np.linspace(0, 1.0, 101))
ax.tick_params(labeltop=True, labelright=True)
for y, label in zip([train_scores, val_scores, test_scores], ["train score", "val score", "test score"]):
plt.plot(range(1, args.n_epochs + 1, args.log_every), y, label=label, linewidth=1)
for y, label in zip(
[train_scores, val_scores, test_scores],
["train score", "val score", "test score"],
):
plt.plot(
range(1, args.n_epochs + 1, args.log_every),
y,
label=label,
linewidth=1,
)
ax.xaxis.set_major_locator(MultipleLocator(100))
ax.xaxis.set_minor_locator(AutoMinorLocator(1))
ax.yaxis.set_major_locator(MultipleLocator(0.01))
......@@ -272,9 +374,15 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
ax.set_xticks(np.arange(0, args.n_epochs, 100))
ax.tick_params(labeltop=True, labelright=True)
for y, label in zip(
[losses, train_losses, val_losses, test_losses], ["loss", "train loss", "val loss", "test loss"]
[losses, train_losses, val_losses, test_losses],
["loss", "train loss", "val loss", "test loss"],
):
plt.plot(range(1, args.n_epochs + 1, args.log_every), y, label=label, linewidth=1)
plt.plot(
range(1, args.n_epochs + 1, args.log_every),
y,
label=label,
linewidth=1,
)
ax.xaxis.set_major_locator(MultipleLocator(100))
ax.xaxis.set_minor_locator(AutoMinorLocator(1))
ax.yaxis.set_major_locator(MultipleLocator(0.1))
......@@ -294,37 +402,79 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
def count_parameters(args):
model = gen_model(args)
return sum([np.prod(p.size()) for p in model.parameters() if p.requires_grad])
return sum(
[np.prod(p.size()) for p in model.parameters() if p.requires_grad]
)
def main():
global device
argparser = argparse.ArgumentParser(
"GAT implementation on ogbn-proteins", formatter_class=argparse.ArgumentDefaultsHelpFormatter
"GAT implementation on ogbn-proteins",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
argparser.add_argument(
"--cpu",
action="store_true",
help="CPU mode. This option overrides '--gpu'.",
)
argparser.add_argument("--cpu", action="store_true", help="CPU mode. This option overrides '--gpu'.")
argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID")
argparser.add_argument("--seed", type=int, default=0, help="random seed")
argparser.add_argument("--n-runs", type=int, default=10, help="running times")
argparser.add_argument("--n-epochs", type=int, default=1200, help="number of epochs")
argparser.add_argument(
"--use-labels", action="store_true", help="Use labels in the training set as input features."
)
argparser.add_argument("--no-attn-dst", action="store_true", help="Don't use attn_dst.")
argparser.add_argument("--n-heads", type=int, default=6, help="number of heads")
argparser.add_argument("--lr", type=float, default=0.01, help="learning rate")
argparser.add_argument("--n-layers", type=int, default=6, help="number of layers")
argparser.add_argument("--n-hidden", type=int, default=80, help="number of hidden units")
argparser.add_argument("--dropout", type=float, default=0.25, help="dropout rate")
argparser.add_argument("--input-drop", type=float, default=0.1, help="input drop rate")
argparser.add_argument("--attn-drop", type=float, default=0.0, help="attention dropout rate")
argparser.add_argument("--edge-drop", type=float, default=0.1, help="edge drop rate")
"--n-runs", type=int, default=10, help="running times"
)
argparser.add_argument(
"--n-epochs", type=int, default=1200, help="number of epochs"
)
argparser.add_argument(
"--use-labels",
action="store_true",
help="Use labels in the training set as input features.",
)
argparser.add_argument(
"--no-attn-dst", action="store_true", help="Don't use attn_dst."
)
argparser.add_argument(
"--n-heads", type=int, default=6, help="number of heads"
)
argparser.add_argument(
"--lr", type=float, default=0.01, help="learning rate"
)
argparser.add_argument(
"--n-layers", type=int, default=6, help="number of layers"
)
argparser.add_argument(
"--n-hidden", type=int, default=80, help="number of hidden units"
)
argparser.add_argument(
"--dropout", type=float, default=0.25, help="dropout rate"
)
argparser.add_argument(
"--input-drop", type=float, default=0.1, help="input drop rate"
)
argparser.add_argument(
"--attn-drop", type=float, default=0.0, help="attention dropout rate"
)
argparser.add_argument(
"--edge-drop", type=float, default=0.1, help="edge drop rate"
)
argparser.add_argument("--wd", type=float, default=0, help="weight decay")
argparser.add_argument("--eval-every", type=int, default=5, help="evaluate every EVAL_EVERY epochs")
argparser.add_argument("--log-every", type=int, default=5, help="log every LOG_EVERY epochs")
argparser.add_argument("--plot", action="store_true", help="plot learning curves")
argparser.add_argument("--save-pred", action="store_true", help="save final predictions")
argparser.add_argument(
"--eval-every",
type=int,
default=5,
help="evaluate every EVAL_EVERY epochs",
)
argparser.add_argument(
"--log-every", type=int, default=5, help="log every LOG_EVERY epochs"
)
argparser.add_argument(
"--plot", action="store_true", help="plot learning curves"
)
argparser.add_argument(
"--save-pred", action="store_true", help="save final predictions"
)
args = argparser.parse_args()
if args.cpu:
......@@ -338,7 +488,9 @@ def main():
print("Preprocessing")
graph, labels = preprocess(graph, labels, train_idx)
labels, train_idx, val_idx, test_idx = map(lambda x: x.to(device), (labels, train_idx, val_idx, test_idx))
labels, train_idx, val_idx, test_idx = map(
lambda x: x.to(device), (labels, train_idx, val_idx, test_idx)
)
# run
val_scores, test_scores = [], []
......@@ -346,7 +498,9 @@ def main():
for i in range(args.n_runs):
print("Running", i)
seed(args.seed + i)
val_score, test_score = run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, i + 1)
val_score, test_score = run(
args, graph, labels, train_idx, val_idx, test_idx, evaluator, i + 1
)
val_scores.append(val_score)
test_scores.append(test_score)
......
import os
import time
import dgl.function as fn
import numpy as np
import torch
import torch.nn as nn
......@@ -11,9 +10,10 @@ from ogb.nodeproppred.dataset_dgl import DglNodePropPredDataset
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter
from utils import load_model, set_random_seed
import dgl.function as fn
def normalize_edge_weights(graph, device, num_ew_channels):
degs = graph.in_degrees().float()
......@@ -24,7 +24,9 @@ def normalize_edge_weights(graph, device, num_ew_channels):
graph.apply_edges(fn.e_div_u("feat", "norm", "feat"))
graph.apply_edges(fn.e_div_v("feat", "norm", "feat"))
for channel in range(num_ew_channels):
graph.edata["feat_" + str(channel)] = graph.edata["feat"][:, channel : channel + 1]
graph.edata["feat_" + str(channel)] = graph.edata["feat"][
:, channel : channel + 1
]
def run_a_train_epoch(graph, node_idx, model, criterion, optimizer, evaluator):
......@@ -50,9 +52,24 @@ def run_an_eval_epoch(graph, splitted_idx, model, evaluator):
labels = graph.ndata["labels"].cpu().numpy()
preds = logits.cpu().detach().numpy()
train_score = evaluator.eval({"y_true": labels[splitted_idx["train"]], "y_pred": preds[splitted_idx["train"]]})
val_score = evaluator.eval({"y_true": labels[splitted_idx["valid"]], "y_pred": preds[splitted_idx["valid"]]})
test_score = evaluator.eval({"y_true": labels[splitted_idx["test"]], "y_pred": preds[splitted_idx["test"]]})
train_score = evaluator.eval(
{
"y_true": labels[splitted_idx["train"]],
"y_pred": preds[splitted_idx["train"]],
}
)
val_score = evaluator.eval(
{
"y_true": labels[splitted_idx["valid"]],
"y_pred": preds[splitted_idx["valid"]],
}
)
test_score = evaluator.eval(
{
"y_true": labels[splitted_idx["test"]],
"y_pred": preds[splitted_idx["test"]],
}
)
return train_score["rocauc"], val_score["rocauc"], test_score["rocauc"]
......@@ -75,12 +92,18 @@ def main(args):
elif args["ewnorm"] == "none":
print("Not normalizing edge weights")
for channel in range(args["num_ew_channels"]):
graph.edata["feat_" + str(channel)] = graph.edata["feat"][:, channel : channel + 1]
graph.edata["feat_" + str(channel)] = graph.edata["feat"][
:, channel : channel + 1
]
model = load_model(args).to(args["device"])
optimizer = Adam(model.parameters(), lr=args["lr"], weight_decay=args["weight_decay"])
optimizer = Adam(
model.parameters(), lr=args["lr"], weight_decay=args["weight_decay"]
)
min_lr = 1e-3
scheduler = ReduceLROnPlateau(optimizer, "max", factor=0.7, patience=100, verbose=True, min_lr=min_lr)
scheduler = ReduceLROnPlateau(
optimizer, "max", factor=0.7, patience=100, verbose=True, min_lr=min_lr
)
print("scheduler min_lr", min_lr)
criterion = nn.BCEWithLogitsLoss()
......@@ -95,7 +118,9 @@ def main(args):
best_val_score = 0.0
num_patient_epochs = 0
model_folder = "./saved_models/"
model_path = model_folder + str(args["exp_name"]) + "_" + str(args["postfix"])
model_path = (
model_folder + str(args["exp_name"]) + "_" + str(args["postfix"])
)
if not os.path.exists(model_folder):
os.makedirs(model_folder)
......@@ -104,7 +129,9 @@ def main(args):
if epoch >= 3:
t0 = time.time()
loss, train_score = run_a_train_epoch(graph, splitted_idx["train"], model, criterion, optimizer, evaluator)
loss, train_score = run_a_train_epoch(
graph, splitted_idx["train"], model, criterion, optimizer, evaluator
)
if epoch >= 3:
dur.append(time.time() - t0)
......@@ -112,7 +139,9 @@ def main(args):
else:
avg_time = None
train_score, val_score, test_score = run_an_eval_epoch(graph, splitted_idx, model, evaluator)
train_score, val_score, test_score = run_an_eval_epoch(
graph, splitted_idx, model, evaluator
)
scheduler.step(val_score)
......@@ -127,7 +156,12 @@ def main(args):
print(
"Epoch {:d}, loss {:.4f}, train score {:.4f}, "
"val score {:.4f}, avg time {}, num patient epochs {:d}".format(
epoch, loss, train_score, val_score, avg_time, num_patient_epochs
epoch,
loss,
train_score,
val_score,
avg_time,
num_patient_epochs,
)
)
......@@ -135,7 +169,9 @@ def main(args):
break
model.load_state_dict(torch.load(model_path))
train_score, val_score, test_score = run_an_eval_epoch(graph, splitted_idx, model, evaluator)
train_score, val_score, test_score = run_an_eval_epoch(
graph, splitted_idx, model, evaluator
)
print("Train score {:.4f}".format(train_score))
print("Valid score {:.4f}".format(val_score))
print("Test score {:.4f}".format(test_score))
......@@ -153,15 +189,34 @@ if __name__ == "__main__":
from configure import get_exp_configure
parser = argparse.ArgumentParser(description="OGB node property prediction with DGL using full graph training")
parser = argparse.ArgumentParser(
description="OGB node property prediction with DGL using full graph training"
)
parser.add_argument(
"-m", "--model", type=str, choices=["MWE-GCN", "MWE-DGCN"], default="MWE-DGCN", help="Model to use"
"-m",
"--model",
type=str,
choices=["MWE-GCN", "MWE-DGCN"],
default="MWE-DGCN",
help="Model to use",
)
parser.add_argument("-c", "--cuda", type=str, default="none")
parser.add_argument("--postfix", type=str, default="", help="a string appended to the file name of the saved model")
parser.add_argument("--rand_seed", type=int, default=-1, help="random seed for torch and numpy")
parser.add_argument(
"--postfix",
type=str,
default="",
help="a string appended to the file name of the saved model",
)
parser.add_argument(
"--rand_seed",
type=int,
default=-1,
help="random seed for torch and numpy",
)
parser.add_argument("--residual", action="store_true")
parser.add_argument("--ewnorm", type=str, default="none", choices=["none", "both"])
parser.add_argument(
"--ewnorm", type=str, default="none", choices=["none", "both"]
)
args = parser.parse_args().__dict__
# Get experiment configuration
......
......@@ -3,7 +3,6 @@ import random
import numpy as np
import torch
import torch.nn.functional as F
from models import MWE_DGCN, MWE_GCN
......@@ -87,4 +86,3 @@ class Logger(object):
print(f" Final Train: {r.mean():.2f} ± {r.std():.2f}")
r = best_result[:, 3]
print(f" Final Test: {r.mean():.2f} ± {r.std():.2f}")
import argparse
import time
import os
import sys
import math
import os
import random
from tqdm import tqdm
import sys
import time
import numpy as np
import torch
from torch.nn import ModuleList, Linear, Conv1d, MaxPool1d, Embedding, BCEWithLogitsLoss
import torch.nn.functional as F
from ogb.linkproppred import DglLinkPropPredDataset, Evaluator
from scipy.sparse.csgraph import shortest_path
from torch.nn import (
BCEWithLogitsLoss,
Conv1d,
Embedding,
Linear,
MaxPool1d,
ModuleList,
)
from tqdm import tqdm
import dgl
from dgl.dataloading import DataLoader, Sampler
from dgl.nn import GraphConv, SortPooling
from dgl.sampling import global_uniform_negative_sampling
from dgl.dataloading import Sampler, DataLoader
from ogb.linkproppred import DglLinkPropPredDataset, Evaluator
from scipy.sparse.csgraph import shortest_path
class Logger(object):
......@@ -32,10 +41,10 @@ class Logger(object):
if run is not None:
result = 100 * torch.tensor(self.results[run])
argmax = result[:, 0].argmax().item()
print(f'Run {run + 1:02d}:', file=f)
print(f'Highest Valid: {result[:, 0].max():.2f}', file=f)
print(f'Highest Eval Point: {argmax + 1}', file=f)
print(f' Final Test: {result[argmax, 1]:.2f}', file=f)
print(f"Run {run + 1:02d}:", file=f)
print(f"Highest Valid: {result[:, 0].max():.2f}", file=f)
print(f"Highest Eval Point: {argmax + 1}", file=f)
print(f" Final Test: {result[argmax, 1]:.2f}", file=f)
else:
result = 100 * torch.tensor(self.results)
......@@ -47,16 +56,23 @@ class Logger(object):
best_result = torch.tensor(best_results)
print(f'All runs:', file=f)
print(f"All runs:", file=f)
r = best_result[:, 0]
print(f'Highest Valid: {r.mean():.2f} ± {r.std():.2f}', file=f)
print(f"Highest Valid: {r.mean():.2f} ± {r.std():.2f}", file=f)
r = best_result[:, 1]
print(f' Final Test: {r.mean():.2f} ± {r.std():.2f}', file=f)
print(f" Final Test: {r.mean():.2f} ± {r.std():.2f}", file=f)
class SealSampler(Sampler):
def __init__(self, g, num_hops=1, sample_ratio=1., directed=False,
prefetch_node_feats=None, prefetch_edge_feats=None):
def __init__(
self,
g,
num_hops=1,
sample_ratio=1.0,
directed=False,
prefetch_node_feats=None,
prefetch_edge_feats=None,
):
super().__init__()
self.g = g
self.num_hops = num_hops
......@@ -71,22 +87,29 @@ class SealSampler(Sampler):
idx = list(range(1)) + list(range(2, N))
adj_wo_dst = adj[idx, :][:, idx]
dist2src = shortest_path(adj_wo_dst, directed=False, unweighted=True, indices=0)
dist2src = shortest_path(
adj_wo_dst, directed=False, unweighted=True, indices=0
)
dist2src = np.insert(dist2src, 1, 0, axis=0)
dist2src = torch.from_numpy(dist2src)
dist2dst = shortest_path(adj_wo_src, directed=False, unweighted=True, indices=0)
dist2dst = shortest_path(
adj_wo_src, directed=False, unweighted=True, indices=0
)
dist2dst = np.insert(dist2dst, 0, 0, axis=0)
dist2dst = torch.from_numpy(dist2dst)
dist = dist2src + dist2dst
dist_over_2, dist_mod_2 = torch.div(dist, 2, rounding_mode='floor'), dist % 2
dist_over_2, dist_mod_2 = (
torch.div(dist, 2, rounding_mode="floor"),
dist % 2,
)
z = 1 + torch.min(dist2src, dist2dst)
z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1)
z[0: 2] = 1.
z[0:2] = 1.0
# shortest path may include inf values
z[torch.isnan(z)] = 0.
z[torch.isnan(z)] = 0.0
return z.to(torch.long)
......@@ -107,9 +130,12 @@ class SealSampler(Sampler):
fringe = np.union1d(in_neighbors, out_neighbors)
fringe = np.setdiff1d(fringe, visited)
visited = np.union1d(visited, fringe)
if self.sample_ratio < 1.:
fringe = np.random.choice(fringe,
int(self.sample_ratio * len(fringe)), replace=False)
if self.sample_ratio < 1.0:
fringe = np.random.choice(
fringe,
int(self.sample_ratio * len(fringe)),
replace=False,
)
if len(fringe) == 0:
break
nodes = np.union1d(nodes, fringe)
......@@ -117,26 +143,34 @@ class SealSampler(Sampler):
# remove edges to predict
edges_to_remove = [
subg.edge_ids(s, t) for s, t in [(0, 1), (1, 0)] if subg.has_edges_between(s, t)]
subg.edge_ids(s, t)
for s, t in [(0, 1), (1, 0)]
if subg.has_edges_between(s, t)
]
subg.remove_edges(edges_to_remove)
# add double radius node labeling
subg.ndata['z'] = self._double_radius_node_labeling(subg.adj(scipy_fmt='csr'))
subg.ndata["z"] = self._double_radius_node_labeling(
subg.adj(scipy_fmt="csr")
)
subg_aug = subg.add_self_loop()
if 'weight' in subg.edata:
subg_aug.edata['weight'][subg.num_edges():] = torch.ones(
subg_aug.num_edges() - subg.num_edges())
if "weight" in subg.edata:
subg_aug.edata["weight"][subg.num_edges() :] = torch.ones(
subg_aug.num_edges() - subg.num_edges()
)
subgraphs.append(subg_aug)
subgraphs = dgl.batch(subgraphs)
dgl.set_src_lazy_features(subg_aug, self.prefetch_node_feats)
dgl.set_edge_lazy_features(subg_aug, self.prefetch_edge_feats)
return subgraphs, aug_g.edata['y'][seed_edges]
return subgraphs, aug_g.edata["y"][seed_edges]
# An end-to-end deep learning architecture for graph classification, AAAI-18.
class DGCNN(torch.nn.Module):
def __init__(self, hidden_channels, num_layers, k, GNN=GraphConv, feature_dim=0):
def __init__(
self, hidden_channels, num_layers, k, GNN=GraphConv, feature_dim=0
):
super(DGCNN, self).__init__()
self.feature_dim = feature_dim
self.k = k
......@@ -149,18 +183,18 @@ class DGCNN(torch.nn.Module):
initial_channels = hidden_channels + self.feature_dim
self.convs.append(GNN(initial_channels, hidden_channels))
for _ in range(0, num_layers-1):
for _ in range(0, num_layers - 1):
self.convs.append(GNN(hidden_channels, hidden_channels))
self.convs.append(GNN(hidden_channels, 1))
conv1d_channels = [16, 32]
total_latent_dim = hidden_channels * num_layers + 1
conv1d_kws = [total_latent_dim, 5]
self.conv1 = Conv1d(1, conv1d_channels[0], conv1d_kws[0],
conv1d_kws[0])
self.conv1 = Conv1d(1, conv1d_channels[0], conv1d_kws[0], conv1d_kws[0])
self.maxpool1d = MaxPool1d(2, 2)
self.conv2 = Conv1d(conv1d_channels[0], conv1d_channels[1],
conv1d_kws[1], 1)
self.conv2 = Conv1d(
conv1d_channels[0], conv1d_channels[1], conv1d_kws[1], 1
)
dense_dim = int((self.k - 2) / 2 + 1)
dense_dim = (dense_dim - conv1d_kws[1] + 1) * conv1d_channels[1]
self.lin1 = Linear(dense_dim, 128)
......@@ -196,33 +230,35 @@ class DGCNN(torch.nn.Module):
def get_pos_neg_edges(split, split_edge, g, percent=100):
pos_edge = split_edge[split]['edge']
if split == 'train':
neg_edge = torch.stack(global_uniform_negative_sampling(
g, num_samples=pos_edge.size(0),
exclude_self_loops=True
), dim=1)
pos_edge = split_edge[split]["edge"]
if split == "train":
neg_edge = torch.stack(
global_uniform_negative_sampling(
g, num_samples=pos_edge.size(0), exclude_self_loops=True
),
dim=1,
)
else:
neg_edge = split_edge[split]['edge_neg']
neg_edge = split_edge[split]["edge_neg"]
# sampling according to the percent param
np.random.seed(123)
# pos sampling
num_pos = pos_edge.size(0)
perm = np.random.permutation(num_pos)
perm = perm[:int(percent / 100 * num_pos)]
perm = perm[: int(percent / 100 * num_pos)]
pos_edge = pos_edge[perm]
# neg sampling
if neg_edge.dim() > 2: # [Np, Nn, 2]
if neg_edge.dim() > 2: # [Np, Nn, 2]
neg_edge = neg_edge[perm].view(-1, 2)
else:
np.random.seed(123)
num_neg = neg_edge.size(0)
perm = np.random.permutation(num_neg)
perm = perm[:int(percent / 100 * num_neg)]
perm = perm[: int(percent / 100 * num_neg)]
neg_edge = neg_edge[perm]
return pos_edge, neg_edge # ([2, Np], [2, Nn]) -> ([Np, 2], [Nn, 2])
return pos_edge, neg_edge # ([2, Np], [2, Nn]) -> ([Np, 2], [Nn, 2])
def train():
......@@ -233,8 +269,12 @@ def train():
pbar = tqdm(train_loader, ncols=70)
for gs, y in pbar:
optimizer.zero_grad()
logits = model(gs, gs.ndata['z'], gs.ndata.get('feat', None),
edge_weight=gs.edata.get('weight', None))
logits = model(
gs,
gs.ndata["z"],
gs.ndata.get("feat", None),
edge_weight=gs.edata.get("weight", None),
)
loss = loss_fnt(logits.view(-1), y.to(torch.float))
loss.backward()
optimizer.step()
......@@ -250,28 +290,40 @@ def test():
y_pred, y_true = [], []
for gs, y in tqdm(val_loader, ncols=70):
logits = model(gs, gs.ndata['z'], gs.ndata.get('feat', None),
edge_weight=gs.edata.get('weight', None))
logits = model(
gs,
gs.ndata["z"],
gs.ndata.get("feat", None),
edge_weight=gs.edata.get("weight", None),
)
y_pred.append(logits.view(-1).cpu())
y_true.append(y.view(-1).cpu().to(torch.float))
val_pred, val_true = torch.cat(y_pred), torch.cat(y_true)
pos_val_pred = val_pred[val_true==1]
neg_val_pred = val_pred[val_true==0]
pos_val_pred = val_pred[val_true == 1]
neg_val_pred = val_pred[val_true == 0]
y_pred, y_true = [], []
for gs, y in tqdm(test_loader, ncols=70):
logits = model(gs, gs.ndata['z'], gs.ndata.get('feat', None),
edge_weight=gs.edata.get('weight', None))
logits = model(
gs,
gs.ndata["z"],
gs.ndata.get("feat", None),
edge_weight=gs.edata.get("weight", None),
)
y_pred.append(logits.view(-1).cpu())
y_true.append(y.view(-1).cpu().to(torch.float))
test_pred, test_true = torch.cat(y_pred), torch.cat(y_true)
pos_test_pred = test_pred[test_true==1]
neg_test_pred = test_pred[test_true==0]
if args.eval_metric == 'hits':
results = evaluate_hits(pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred)
elif args.eval_metric == 'mrr':
results = evaluate_mrr(pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred)
pos_test_pred = test_pred[test_true == 1]
neg_test_pred = test_pred[test_true == 0]
if args.eval_metric == "hits":
results = evaluate_hits(
pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred
)
elif args.eval_metric == "mrr":
results = evaluate_mrr(
pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred
)
return results
......@@ -280,184 +332,254 @@ def evaluate_hits(pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred):
results = {}
for K in [20, 50, 100]:
evaluator.K = K
valid_hits = evaluator.eval({
'y_pred_pos': pos_val_pred,
'y_pred_neg': neg_val_pred,
})[f'hits@{K}']
test_hits = evaluator.eval({
'y_pred_pos': pos_test_pred,
'y_pred_neg': neg_test_pred,
})[f'hits@{K}']
results[f'Hits@{K}'] = (valid_hits, test_hits)
valid_hits = evaluator.eval(
{
"y_pred_pos": pos_val_pred,
"y_pred_neg": neg_val_pred,
}
)[f"hits@{K}"]
test_hits = evaluator.eval(
{
"y_pred_pos": pos_test_pred,
"y_pred_neg": neg_test_pred,
}
)[f"hits@{K}"]
results[f"Hits@{K}"] = (valid_hits, test_hits)
return results
def evaluate_mrr(pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred):
print(pos_val_pred.size(), neg_val_pred.size(), pos_test_pred.size(), neg_test_pred.size())
print(
pos_val_pred.size(),
neg_val_pred.size(),
pos_test_pred.size(),
neg_test_pred.size(),
)
neg_val_pred = neg_val_pred.view(pos_val_pred.shape[0], -1)
neg_test_pred = neg_test_pred.view(pos_test_pred.shape[0], -1)
results = {}
valid_mrr = evaluator.eval({
'y_pred_pos': pos_val_pred,
'y_pred_neg': neg_val_pred,
})['mrr_list'].mean().item()
test_mrr = evaluator.eval({
'y_pred_pos': pos_test_pred,
'y_pred_neg': neg_test_pred,
})['mrr_list'].mean().item()
results['MRR'] = (valid_mrr, test_mrr)
valid_mrr = (
evaluator.eval(
{
"y_pred_pos": pos_val_pred,
"y_pred_neg": neg_val_pred,
}
)["mrr_list"]
.mean()
.item()
)
test_mrr = (
evaluator.eval(
{
"y_pred_pos": pos_test_pred,
"y_pred_neg": neg_test_pred,
}
)["mrr_list"]
.mean()
.item()
)
results["MRR"] = (valid_mrr, test_mrr)
return results
if __name__ == '__main__':
if __name__ == "__main__":
# Data settings
parser = argparse.ArgumentParser(description='OGBL (SEAL)')
parser.add_argument('--dataset', type=str, default='ogbl-collab')
parser = argparse.ArgumentParser(description="OGBL (SEAL)")
parser.add_argument("--dataset", type=str, default="ogbl-collab")
# GNN settings
parser.add_argument('--sortpool_k', type=float, default=0.6)
parser.add_argument('--num_layers', type=int, default=3)
parser.add_argument('--hidden_channels', type=int, default=32)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument("--sortpool_k", type=float, default=0.6)
parser.add_argument("--num_layers", type=int, default=3)
parser.add_argument("--hidden_channels", type=int, default=32)
parser.add_argument("--batch_size", type=int, default=32)
# Subgraph extraction settings
parser.add_argument('--ratio_per_hop', type=float, default=1.0)
parser.add_argument('--use_feature', action='store_true',
help="whether to use raw node features as GNN input")
parser.add_argument('--use_edge_weight', action='store_true',
help="whether to consider edge weight in GNN")
parser.add_argument("--ratio_per_hop", type=float, default=1.0)
parser.add_argument(
"--use_feature",
action="store_true",
help="whether to use raw node features as GNN input",
)
parser.add_argument(
"--use_edge_weight",
action="store_true",
help="whether to consider edge weight in GNN",
)
# Training settings
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--epochs', type=int, default=50)
parser.add_argument('--runs', type=int, default=10)
parser.add_argument('--train_percent', type=float, default=100)
parser.add_argument('--val_percent', type=float, default=100)
parser.add_argument('--test_percent', type=float, default=100)
parser.add_argument('--num_workers', type=int, default=8,
help="number of workers for dynamic dataloaders")
parser.add_argument("--lr", type=float, default=0.0001)
parser.add_argument("--epochs", type=int, default=50)
parser.add_argument("--runs", type=int, default=10)
parser.add_argument("--train_percent", type=float, default=100)
parser.add_argument("--val_percent", type=float, default=100)
parser.add_argument("--test_percent", type=float, default=100)
parser.add_argument(
"--num_workers",
type=int,
default=8,
help="number of workers for dynamic dataloaders",
)
# Testing settings
parser.add_argument('--use_valedges_as_input', action='store_true')
parser.add_argument('--eval_steps', type=int, default=1)
parser.add_argument("--use_valedges_as_input", action="store_true")
parser.add_argument("--eval_steps", type=int, default=1)
args = parser.parse_args()
data_appendix = '_rph{}'.format(''.join(str(args.ratio_per_hop).split('.')))
data_appendix = "_rph{}".format("".join(str(args.ratio_per_hop).split(".")))
if args.use_valedges_as_input:
data_appendix += '_uvai'
data_appendix += "_uvai"
args.res_dir = os.path.join('results/{}_{}'.format(args.dataset,
time.strftime("%Y%m%d%H%M%S")))
print('Results will be saved in ' + args.res_dir)
args.res_dir = os.path.join(
"results/{}_{}".format(args.dataset, time.strftime("%Y%m%d%H%M%S"))
)
print("Results will be saved in " + args.res_dir)
if not os.path.exists(args.res_dir):
os.makedirs(args.res_dir)
log_file = os.path.join(args.res_dir, 'log.txt')
os.makedirs(args.res_dir)
log_file = os.path.join(args.res_dir, "log.txt")
# Save command line input.
cmd_input = 'python ' + ' '.join(sys.argv) + '\n'
with open(os.path.join(args.res_dir, 'cmd_input.txt'), 'a') as f:
cmd_input = "python " + " ".join(sys.argv) + "\n"
with open(os.path.join(args.res_dir, "cmd_input.txt"), "a") as f:
f.write(cmd_input)
print('Command line input: ' + cmd_input + ' is saved.')
with open(log_file, 'a') as f:
f.write('\n' + cmd_input)
print("Command line input: " + cmd_input + " is saved.")
with open(log_file, "a") as f:
f.write("\n" + cmd_input)
dataset = DglLinkPropPredDataset(name=args.dataset)
split_edge = dataset.get_edge_split()
graph = dataset[0]
# re-format the data of citation2
if args.dataset == 'ogbl-citation2':
for k in ['train', 'valid', 'test']:
src = split_edge[k]['source_node']
tgt = split_edge[k]['target_node']
split_edge[k]['edge'] = torch.stack([src, tgt], dim=1)
if k != 'train':
tgt_neg = split_edge[k]['target_node_neg']
split_edge[k]['edge_neg'] = torch.stack([
src[:, None].repeat(1, tgt_neg.size(1)),
tgt_neg
], dim=-1) # [Ns, Nt, 2]
if args.dataset == "ogbl-citation2":
for k in ["train", "valid", "test"]:
src = split_edge[k]["source_node"]
tgt = split_edge[k]["target_node"]
split_edge[k]["edge"] = torch.stack([src, tgt], dim=1)
if k != "train":
tgt_neg = split_edge[k]["target_node_neg"]
split_edge[k]["edge_neg"] = torch.stack(
[src[:, None].repeat(1, tgt_neg.size(1)), tgt_neg], dim=-1
) # [Ns, Nt, 2]
# reconstruct the graph for ogbl-collab data for validation edge augmentation and coalesce
if args.dataset == 'ogbl-collab':
if args.dataset == "ogbl-collab":
if args.use_valedges_as_input:
val_edges = split_edge['valid']['edge']
val_edges = split_edge["valid"]["edge"]
row, col = val_edges.t()
# float edata for to_simple transform
graph.edata.pop('year')
graph.edata['weight'] = graph.edata['weight'].to(torch.float)
graph.edata.pop("year")
graph.edata["weight"] = graph.edata["weight"].to(torch.float)
val_weights = torch.ones(size=(val_edges.size(0), 1))
graph.add_edges(torch.cat([row, col]), torch.cat([col, row]), {'weight': val_weights})
graph = graph.to_simple(copy_edata=True, aggregator='sum')
if not args.use_edge_weight and 'weight' in graph.edata:
graph.edata.pop('weight')
if not args.use_feature and 'feat' in graph.ndata:
graph.ndata.pop('feat')
if args.dataset.startswith('ogbl-citation'):
args.eval_metric = 'mrr'
graph.add_edges(
torch.cat([row, col]),
torch.cat([col, row]),
{"weight": val_weights},
)
graph = graph.to_simple(copy_edata=True, aggregator="sum")
if not args.use_edge_weight and "weight" in graph.edata:
graph.edata.pop("weight")
if not args.use_feature and "feat" in graph.ndata:
graph.ndata.pop("feat")
if args.dataset.startswith("ogbl-citation"):
args.eval_metric = "mrr"
directed = True
else:
args.eval_metric = 'hits'
args.eval_metric = "hits"
directed = False
evaluator = Evaluator(name=args.dataset)
if args.eval_metric == 'hits':
if args.eval_metric == "hits":
loggers = {
'Hits@20': Logger(args.runs, args),
'Hits@50': Logger(args.runs, args),
'Hits@100': Logger(args.runs, args),
"Hits@20": Logger(args.runs, args),
"Hits@50": Logger(args.runs, args),
"Hits@100": Logger(args.runs, args),
}
elif args.eval_metric == 'mrr':
elif args.eval_metric == "mrr":
loggers = {
'MRR': Logger(args.runs, args),
"MRR": Logger(args.runs, args),
}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
path = dataset.root + '_seal{}'.format(data_appendix)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
path = dataset.root + "_seal{}".format(data_appendix)
loaders = []
prefetch_node_feats = ['feat'] if 'feat' in graph.ndata else None
prefetch_edge_feats = ['weight'] if 'weight' in graph.edata else None
train_edge, train_edge_neg = get_pos_neg_edges('train', split_edge, graph, args.train_percent)
val_edge, val_edge_neg = get_pos_neg_edges('valid', split_edge, graph, args.val_percent)
test_edge, test_edge_neg = get_pos_neg_edges('test', split_edge, graph, args.test_percent)
prefetch_node_feats = ["feat"] if "feat" in graph.ndata else None
prefetch_edge_feats = ["weight"] if "weight" in graph.edata else None
train_edge, train_edge_neg = get_pos_neg_edges(
"train", split_edge, graph, args.train_percent
)
val_edge, val_edge_neg = get_pos_neg_edges(
"valid", split_edge, graph, args.val_percent
)
test_edge, test_edge_neg = get_pos_neg_edges(
"test", split_edge, graph, args.test_percent
)
# create an augmented graph for sampling
aug_g = dgl.graph(graph.edges())
aug_g.edata['y'] = torch.ones(aug_g.num_edges())
aug_edges = torch.cat([val_edge, test_edge, train_edge_neg, val_edge_neg, test_edge_neg])
aug_labels = torch.cat([
torch.ones(len(val_edge) + len(test_edge)),
torch.zeros(len(train_edge_neg) + len(val_edge_neg) + len(test_edge_neg))
])
aug_g.add_edges(aug_edges[:, 0], aug_edges[:, 1], {'y': aug_labels})
aug_g.edata["y"] = torch.ones(aug_g.num_edges())
aug_edges = torch.cat(
[val_edge, test_edge, train_edge_neg, val_edge_neg, test_edge_neg]
)
aug_labels = torch.cat(
[
torch.ones(len(val_edge) + len(test_edge)),
torch.zeros(
len(train_edge_neg) + len(val_edge_neg) + len(test_edge_neg)
),
]
)
aug_g.add_edges(aug_edges[:, 0], aug_edges[:, 1], {"y": aug_labels})
# eids for sampling
split_len = [graph.num_edges()] + \
list(map(len, [val_edge, test_edge, train_edge_neg, val_edge_neg, test_edge_neg]))
train_eids = torch.cat([
graph.edge_ids(train_edge[:, 0], train_edge[:, 1]),
torch.arange(sum(split_len[:3]), sum(split_len[:4]))
])
val_eids = torch.cat([
torch.arange(sum(split_len[:1]), sum(split_len[:2])),
torch.arange(sum(split_len[:4]), sum(split_len[:5]))
])
test_eids = torch.cat([
torch.arange(sum(split_len[:2]), sum(split_len[:3])),
torch.arange(sum(split_len[:5]), sum(split_len[:6]))
])
sampler = SealSampler(graph, 1, args.ratio_per_hop, directed,
prefetch_node_feats, prefetch_edge_feats)
split_len = [graph.num_edges()] + list(
map(
len,
[val_edge, test_edge, train_edge_neg, val_edge_neg, test_edge_neg],
)
)
train_eids = torch.cat(
[
graph.edge_ids(train_edge[:, 0], train_edge[:, 1]),
torch.arange(sum(split_len[:3]), sum(split_len[:4])),
]
)
val_eids = torch.cat(
[
torch.arange(sum(split_len[:1]), sum(split_len[:2])),
torch.arange(sum(split_len[:4]), sum(split_len[:5])),
]
)
test_eids = torch.cat(
[
torch.arange(sum(split_len[:2]), sum(split_len[:3])),
torch.arange(sum(split_len[:5]), sum(split_len[:6])),
]
)
sampler = SealSampler(
graph,
1,
args.ratio_per_hop,
directed,
prefetch_node_feats,
prefetch_edge_feats,
)
# force to be dynamic for consistent dataloading
for split, shuffle, eids in zip(
['train', 'valid', 'test'],
["train", "valid", "test"],
[True, False, False],
[train_eids, val_eids, test_eids]
[train_eids, val_eids, test_eids],
):
data_loader = DataLoader(aug_g, eids, sampler, shuffle=shuffle, device=device,
batch_size=args.batch_size, num_workers=args.num_workers)
data_loader = DataLoader(
aug_g,
eids,
sampler,
shuffle=shuffle,
device=device,
batch_size=args.batch_size,
num_workers=args.num_workers,
)
loaders.append(data_loader)
train_loader, val_loader, test_loader = loaders
......@@ -474,16 +596,20 @@ if __name__ == '__main__':
k = max(k, 10)
for run in range(args.runs):
model = DGCNN(args.hidden_channels, args.num_layers, k,
feature_dim=graph.ndata['feat'].size(1) if args.use_feature else 0).to(device)
model = DGCNN(
args.hidden_channels,
args.num_layers,
k,
feature_dim=graph.ndata["feat"].size(1) if args.use_feature else 0,
).to(device)
parameters = list(model.parameters())
optimizer = torch.optim.Adam(params=parameters, lr=args.lr)
total_params = sum(p.numel() for param in parameters for p in param)
print(f'Total number of parameters is {total_params}')
print(f'SortPooling k is set to {k}')
with open(log_file, 'a') as f:
print(f'Total number of parameters is {total_params}', file=f)
print(f'SortPooling k is set to {k}', file=f)
print(f"Total number of parameters is {total_params}")
print(f"SortPooling k is set to {k}")
with open(log_file, "a") as f:
print(f"Total number of parameters is {total_params}", file=f)
print(f"SortPooling k is set to {k}", file=f)
start_epoch = 1
# Training starts
......@@ -496,35 +622,41 @@ if __name__ == '__main__':
loggers[key].add_result(run, result)
model_name = os.path.join(
args.res_dir, 'run{}_model_checkpoint{}.pth'.format(run+1, epoch))
args.res_dir,
"run{}_model_checkpoint{}.pth".format(run + 1, epoch),
)
optimizer_name = os.path.join(
args.res_dir, 'run{}_optimizer_checkpoint{}.pth'.format(run+1, epoch))
args.res_dir,
"run{}_optimizer_checkpoint{}.pth".format(run + 1, epoch),
)
torch.save(model.state_dict(), model_name)
torch.save(optimizer.state_dict(), optimizer_name)
for key, result in results.items():
valid_res, test_res = result
to_print = (f'Run: {run + 1:02d}, Epoch: {epoch:02d}, ' +
f'Loss: {loss:.4f}, Valid: {100 * valid_res:.2f}%, ' +
f'Test: {100 * test_res:.2f}%')
to_print = (
f"Run: {run + 1:02d}, Epoch: {epoch:02d}, "
+ f"Loss: {loss:.4f}, Valid: {100 * valid_res:.2f}%, "
+ f"Test: {100 * test_res:.2f}%"
)
print(key)
print(to_print)
with open(log_file, 'a') as f:
with open(log_file, "a") as f:
print(key, file=f)
print(to_print, file=f)
for key in loggers.keys():
print(key)
loggers[key].print_statistics(run)
with open(log_file, 'a') as f:
with open(log_file, "a") as f:
print(key, file=f)
loggers[key].print_statistics(run, f=f)
for key in loggers.keys():
print(key)
loggers[key].print_statistics()
with open(log_file, 'a') as f:
with open(log_file, "a") as f:
print(key, file=f)
loggers[key].print_statistics(f=f)
print(f'Total number of parameters is {total_params}')
print(f'Results are saved in {args.res_dir}')
print(f"Total number of parameters is {total_params}")
print(f"Results are saved in {args.res_dir}")
import torch
import numpy as np
import torch
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
import dgl
import dgl.function as fn
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
def get_ogb_evaluator(dataset):
......@@ -10,10 +11,12 @@ def get_ogb_evaluator(dataset):
Get evaluator from Open Graph Benchmark based on dataset
"""
evaluator = Evaluator(name=dataset)
return lambda preds, labels: evaluator.eval({
"y_true": labels.view(-1, 1),
"y_pred": preds.view(-1, 1),
})["acc"]
return lambda preds, labels: evaluator.eval(
{
"y_true": labels.view(-1, 1),
"y_pred": preds.view(-1, 1),
}
)["acc"]
def convert_mag_to_homograph(g, device):
......@@ -25,11 +28,13 @@ def convert_mag_to_homograph(g, device):
src_writes, dst_writes = g.all_edges(etype="writes")
src_topic, dst_topic = g.all_edges(etype="has_topic")
src_aff, dst_aff = g.all_edges(etype="affiliated_with")
new_g = dgl.heterograph({
("paper", "written", "author"): (dst_writes, src_writes),
("paper", "has_topic", "field"): (src_topic, dst_topic),
("author", "aff", "inst"): (src_aff, dst_aff)
})
new_g = dgl.heterograph(
{
("paper", "written", "author"): (dst_writes, src_writes),
("paper", "has_topic", "field"): (src_topic, dst_topic),
("author", "aff", "inst"): (src_aff, dst_aff),
}
)
new_g = new_g.to(device)
new_g.nodes["paper"].data["feat"] = g.nodes["paper"].data["feat"]
new_g["written"].update_all(fn.copy_u("feat", "m"), fn.mean("m", "feat"))
......@@ -65,7 +70,7 @@ def load_dataset(name, device):
if name == "ogbn-arxiv":
g = dgl.add_reverse_edges(g, copy_ndata=True)
g = dgl.add_self_loop(g)
g.ndata['feat'] = g.ndata['feat'].float()
g.ndata["feat"] = g.ndata["feat"].float()
elif name == "ogbn-mag":
# MAG is a heterogeneous graph. The task is to make prediction for
# paper nodes
......@@ -75,16 +80,18 @@ def load_dataset(name, device):
test_nid = test_nid["paper"]
g = convert_mag_to_homograph(g, device)
else:
g.ndata['feat'] = g.ndata['feat'].float()
g.ndata["feat"] = g.ndata["feat"].float()
n_classes = dataset.num_classes
labels = labels.squeeze()
evaluator = get_ogb_evaluator(name)
print(f"# Nodes: {g.number_of_nodes()}\n"
f"# Edges: {g.number_of_edges()}\n"
f"# Train: {len(train_nid)}\n"
f"# Val: {len(val_nid)}\n"
f"# Test: {len(test_nid)}\n"
f"# Classes: {n_classes}")
print(
f"# Nodes: {g.number_of_nodes()}\n"
f"# Edges: {g.number_of_edges()}\n"
f"# Train: {len(train_nid)}\n"
f"# Val: {len(val_nid)}\n"
f"# Test: {len(test_nid)}\n"
f"# Classes: {n_classes}"
)
return g, labels, n_classes, train_nid, val_nid, test_nid, evaluator
import argparse
import time
import numpy as np
import torch
import torch.nn as nn
from dataset import load_dataset
import dgl
import dgl.function as fn
from dataset import load_dataset
class FeedForwardNet(nn.Module):
......@@ -40,8 +42,16 @@ class FeedForwardNet(nn.Module):
class SIGN(nn.Module):
def __init__(self, in_feats, hidden, out_feats, num_hops, n_layers,
dropout, input_drop):
def __init__(
self,
in_feats,
hidden,
out_feats,
num_hops,
n_layers,
dropout,
input_drop,
):
super(SIGN, self).__init__()
self.dropout = nn.Dropout(dropout)
self.prelu = nn.PReLU()
......@@ -49,9 +59,11 @@ class SIGN(nn.Module):
self.input_drop = nn.Dropout(input_drop)
for hop in range(num_hops):
self.inception_ffs.append(
FeedForwardNet(in_feats, hidden, hidden, n_layers, dropout))
self.project = FeedForwardNet(num_hops * hidden, hidden, out_feats,
n_layers, dropout)
FeedForwardNet(in_feats, hidden, hidden, n_layers, dropout)
)
self.project = FeedForwardNet(
num_hops * hidden, hidden, out_feats, n_layers, dropout
)
def forward(self, feats):
feats = [self.input_drop(feat) for feat in feats]
......@@ -72,7 +84,7 @@ def get_n_params(model):
for p in list(model.parameters()):
nn = 1
for s in list(p.size()):
nn = nn*s
nn = nn * s
pp += nn
return pp
......@@ -84,8 +96,9 @@ def neighbor_average_features(g, args):
print("Compute neighbor-averaged feats")
g.ndata["feat_0"] = g.ndata["feat"]
for hop in range(1, args.R + 1):
g.update_all(fn.copy_u(f"feat_{hop-1}", "msg"),
fn.mean("msg", f"feat_{hop}"))
g.update_all(
fn.copy_u(f"feat_{hop-1}", "msg"), fn.mean("msg", f"feat_{hop}")
)
res = []
for hop in range(args.R + 1):
res.append(g.ndata.pop(f"feat_{hop}"))
......@@ -98,8 +111,9 @@ def neighbor_average_features(g, args):
num_target = target_mask.sum().item()
new_res = []
for x in res:
feat = torch.zeros((num_target,) + x.shape[1:],
dtype=x.dtype, device=x.device)
feat = torch.zeros(
(num_target,) + x.shape[1:], dtype=x.dtype, device=x.device
)
feat[target_ids] = x[target_mask]
new_res.append(feat)
res = new_res
......@@ -112,15 +126,23 @@ def prepare_data(device, args):
"""
data = load_dataset(args.dataset, device)
g, labels, n_classes, train_nid, val_nid, test_nid, evaluator = data
in_feats = g.ndata['feat'].shape[1]
in_feats = g.ndata["feat"].shape[1]
feats = neighbor_average_features(g, args)
labels = labels.to(device)
# move to device
train_nid = train_nid.to(device)
val_nid = val_nid.to(device)
test_nid = test_nid.to(device)
return feats, labels, in_feats, n_classes, \
train_nid, val_nid, test_nid, evaluator
return (
feats,
labels,
in_feats,
n_classes,
train_nid,
val_nid,
test_nid,
evaluator,
)
def train(model, feats, labels, loss_fcn, optimizer, train_loader):
......@@ -134,8 +156,9 @@ def train(model, feats, labels, loss_fcn, optimizer, train_loader):
optimizer.step()
def test(model, feats, labels, test_loader, evaluator,
train_nid, val_nid, test_nid):
def test(
model, feats, labels, test_loader, evaluator, train_nid, val_nid, test_nid
):
model.eval()
device = labels.device
preds = []
......@@ -151,24 +174,44 @@ def test(model, feats, labels, test_loader, evaluator,
def run(args, data, device):
feats, labels, in_size, num_classes, \
train_nid, val_nid, test_nid, evaluator = data
(
feats,
labels,
in_size,
num_classes,
train_nid,
val_nid,
test_nid,
evaluator,
) = data
train_loader = torch.utils.data.DataLoader(
train_nid, batch_size=args.batch_size, shuffle=True, drop_last=False)
train_nid, batch_size=args.batch_size, shuffle=True, drop_last=False
)
test_loader = torch.utils.data.DataLoader(
torch.arange(labels.shape[0]), batch_size=args.eval_batch_size,
shuffle=False, drop_last=False)
torch.arange(labels.shape[0]),
batch_size=args.eval_batch_size,
shuffle=False,
drop_last=False,
)
# Initialize model and optimizer for each run
num_hops = args.R + 1
model = SIGN(in_size, args.num_hidden, num_classes, num_hops,
args.ff_layer, args.dropout, args.input_dropout)
model = SIGN(
in_size,
args.num_hidden,
num_classes,
num_hops,
args.ff_layer,
args.dropout,
args.input_dropout,
)
model = model.to(device)
print("# Params:", get_n_params(model))
loss_fcn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
weight_decay=args.weight_decay)
optimizer = torch.optim.Adam(
model.parameters(), lr=args.lr, weight_decay=args.weight_decay
)
# Start training
best_epoch = 0
......@@ -180,8 +223,16 @@ def run(args, data, device):
if epoch % args.eval_every == 0:
with torch.no_grad():
acc = test(model, feats, labels, test_loader, evaluator,
train_nid, val_nid, test_nid)
acc = test(
model,
feats,
labels,
test_loader,
evaluator,
train_nid,
val_nid,
test_nid,
)
end = time.time()
log = "Epoch {}, Time(s): {:.4f}, ".format(epoch, end - start)
log += "Acc: Train {:.4f}, Val {:.4f}, Test {:.4f}".format(*acc)
......@@ -191,8 +242,11 @@ def run(args, data, device):
best_val = acc[1]
best_test = acc[2]
print("Best Epoch {}, Val {:.4f}, Test {:.4f}".format(
best_epoch, best_val, best_test))
print(
"Best Epoch {}, Val {:.4f}, Test {:.4f}".format(
best_epoch, best_val, best_test
)
)
return best_val, best_test
......@@ -212,34 +266,51 @@ def main(args):
val_accs.append(best_val)
test_accs.append(best_test)
print(f"Average val accuracy: {np.mean(val_accs):.4f}, "
f"std: {np.std(val_accs):.4f}")
print(f"Average test accuracy: {np.mean(test_accs):.4f}, "
f"std: {np.std(test_accs):.4f}")
print(
f"Average val accuracy: {np.mean(val_accs):.4f}, "
f"std: {np.std(val_accs):.4f}"
)
print(
f"Average test accuracy: {np.mean(test_accs):.4f}, "
f"std: {np.std(test_accs):.4f}"
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="SIGN")
parser.add_argument("--num-epochs", type=int, default=1000)
parser.add_argument("--num-hidden", type=int, default=512)
parser.add_argument("--R", type=int, default=5,
help="number of hops")
parser.add_argument("--R", type=int, default=5, help="number of hops")
parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--dataset", type=str, default="ogbn-mag")
parser.add_argument("--dropout", type=float, default=0.5,
help="dropout on activation")
parser.add_argument(
"--dropout", type=float, default=0.5, help="dropout on activation"
)
parser.add_argument("--gpu", type=int, default=0)
parser.add_argument("--weight-decay", type=float, default=0)
parser.add_argument("--eval-every", type=int, default=10)
parser.add_argument("--batch-size", type=int, default=50000)
parser.add_argument("--eval-batch-size", type=int, default=100000,
help="evaluation batch size")
parser.add_argument("--ff-layer", type=int, default=2,
help="number of feed-forward layers")
parser.add_argument("--input-dropout", type=float, default=0,
help="dropout on input features")
parser.add_argument("--num-runs", type=int, default=10,
help="number of times to repeat the experiment")
parser.add_argument(
"--eval-batch-size",
type=int,
default=100000,
help="evaluation batch size",
)
parser.add_argument(
"--ff-layer", type=int, default=2, help="number of feed-forward layers"
)
parser.add_argument(
"--input-dropout",
type=float,
default=0,
help="dropout on input features",
)
parser.add_argument(
"--num-runs",
type=int,
default=10,
help="number of times to repeat the experiment",
)
args = parser.parse_args()
print(args)
......
import ogb
from ogb.lsc import MAG240MDataset
import tqdm
import argparse
import os
import numpy as np
import ogb
import torch
import tqdm
from ogb.lsc import MAG240MDataset
import dgl
import dgl.function as fn
import argparse
import os
parser = argparse.ArgumentParser()
parser.add_argument('--rootdir', type=str, default='.', help='Directory to download the OGB dataset.')
parser.add_argument('--author-output-path', type=str, help='Path to store the author features.')
parser.add_argument('--inst-output-path', type=str,
help='Path to store the institution features.')
parser.add_argument('--graph-output-path', type=str, help='Path to store the graph.')
parser.add_argument('--graph-format', type=str, default='csc', help='Graph format (coo, csr or csc).')
parser.add_argument('--graph-as-homogeneous', action='store_true', help='Store the graph as DGL homogeneous graph.')
parser.add_argument('--full-output-path', type=str,
help='Path to store features of all nodes. Effective only when graph is homogeneous.')
parser.add_argument(
"--rootdir",
type=str,
default=".",
help="Directory to download the OGB dataset.",
)
parser.add_argument(
"--author-output-path", type=str, help="Path to store the author features."
)
parser.add_argument(
"--inst-output-path",
type=str,
help="Path to store the institution features.",
)
parser.add_argument(
"--graph-output-path", type=str, help="Path to store the graph."
)
parser.add_argument(
"--graph-format",
type=str,
default="csc",
help="Graph format (coo, csr or csc).",
)
parser.add_argument(
"--graph-as-homogeneous",
action="store_true",
help="Store the graph as DGL homogeneous graph.",
)
parser.add_argument(
"--full-output-path",
type=str,
help="Path to store features of all nodes. Effective only when graph is homogeneous.",
)
args = parser.parse_args()
print('Building graph')
print("Building graph")
dataset = MAG240MDataset(root=args.rootdir)
ei_writes = dataset.edge_index('author', 'writes', 'paper')
ei_cites = dataset.edge_index('paper', 'paper')
ei_affiliated = dataset.edge_index('author', 'institution')
ei_writes = dataset.edge_index("author", "writes", "paper")
ei_cites = dataset.edge_index("paper", "paper")
ei_affiliated = dataset.edge_index("author", "institution")
# We sort the nodes starting with the authors, then the institutions, then the papers.
author_offset = 0
inst_offset = author_offset + dataset.num_authors
paper_offset = inst_offset + dataset.num_institutions
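# Illustrative layout of the combined ID space these offsets define (the
# counts below are made up for the example, not MAG240M's real sizes):
#   num_authors = 4, num_institutions = 2, num_papers = 3
#   author_offset = 0      -> authors      occupy global IDs 0..3
#   inst_offset   = 0 + 4  -> institutions occupy global IDs 4..5
#   paper_offset  = 4 + 2  -> papers       occupy global IDs 6..8
# so a local paper index p maps to the global ID paper_offset + p.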
g = dgl.heterograph({
('author', 'write', 'paper'): (ei_writes[0], ei_writes[1]),
('paper', 'write-by', 'author'): (ei_writes[1], ei_writes[0]),
('author', 'affiliate-with', 'institution'): (ei_affiliated[0], ei_affiliated[1]),
('institution', 'affiliate', 'author'): (ei_affiliated[1], ei_affiliated[0]),
('paper', 'cite', 'paper'): (np.concatenate([ei_cites[0], ei_cites[1]]), np.concatenate([ei_cites[1], ei_cites[0]]))
})
g = dgl.heterograph(
{
("author", "write", "paper"): (ei_writes[0], ei_writes[1]),
("paper", "write-by", "author"): (ei_writes[1], ei_writes[0]),
("author", "affiliate-with", "institution"): (
ei_affiliated[0],
ei_affiliated[1],
),
("institution", "affiliate", "author"): (
ei_affiliated[1],
ei_affiliated[0],
),
("paper", "cite", "paper"): (
np.concatenate([ei_cites[0], ei_cites[1]]),
np.concatenate([ei_cites[1], ei_cites[0]]),
),
}
)
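# Each forward relation above is paired with an explicit reverse relation
# ("write-by", "affiliate") so messages can later flow in both directions, and
# the citation edges are symmetrized by concatenating both directions.  The
# check below is a hedged sanity check added for illustration, not original code:
assert g.num_edges("cite") == 2 * ei_cites.shape[1]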
paper_feat = dataset.paper_feat
author_feat = np.memmap(args.author_output_path, mode='w+', dtype='float16', shape=(dataset.num_authors, dataset.num_paper_features))
inst_feat = np.memmap(args.inst_output_path, mode='w+', dtype='float16', shape=(dataset.num_institutions, dataset.num_paper_features))
author_feat = np.memmap(
args.author_output_path,
mode="w+",
dtype="float16",
shape=(dataset.num_authors, dataset.num_paper_features),
)
inst_feat = np.memmap(
args.inst_output_path,
mode="w+",
dtype="float16",
shape=(dataset.num_institutions, dataset.num_paper_features),
)
# Derive author and institution features block by block along the feature dimension.
BLOCK_COLS = 16
with tqdm.trange(0, dataset.num_paper_features, BLOCK_COLS) as tq:
for start in tq:
tq.set_postfix_str('Reading paper features...')
g.nodes['paper'].data['x'] = torch.FloatTensor(paper_feat[:, start:start + BLOCK_COLS].astype('float32'))
tq.set_postfix_str("Reading paper features...")
g.nodes["paper"].data["x"] = torch.FloatTensor(
paper_feat[:, start : start + BLOCK_COLS].astype("float32")
)
# Compute author features...
tq.set_postfix_str('Computing author features...')
g.update_all(fn.copy_u('x', 'm'), fn.mean('m', 'x'), etype='write-by')
tq.set_postfix_str("Computing author features...")
g.update_all(fn.copy_u("x", "m"), fn.mean("m", "x"), etype="write-by")
# Then institution features...
tq.set_postfix_str('Computing institution features...')
g.update_all(fn.copy_u('x', 'm'), fn.mean('m', 'x'), etype='affiliate-with')
tq.set_postfix_str('Writing author features...')
author_feat[:, start:start + BLOCK_COLS] = g.nodes['author'].data['x'].numpy().astype('float16')
tq.set_postfix_str('Writing institution features...')
inst_feat[:, start:start + BLOCK_COLS] = g.nodes['institution'].data['x'].numpy().astype('float16')
del g.nodes['paper'].data['x']
del g.nodes['author'].data['x']
del g.nodes['institution'].data['x']
tq.set_postfix_str("Computing institution features...")
g.update_all(
fn.copy_u("x", "m"), fn.mean("m", "x"), etype="affiliate-with"
)
tq.set_postfix_str("Writing author features...")
author_feat[:, start : start + BLOCK_COLS] = (
g.nodes["author"].data["x"].numpy().astype("float16")
)
tq.set_postfix_str("Writing institution features...")
inst_feat[:, start : start + BLOCK_COLS] = (
g.nodes["institution"].data["x"].numpy().astype("float16")
)
del g.nodes["paper"].data["x"]
del g.nodes["author"].data["x"]
del g.nodes["institution"].data["x"]
author_feat.flush()
inst_feat.flush()
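# Why the column blocking: materializing all paper features in float32 at once
# would take num_papers * num_paper_features * 4 bytes (on the order of
# 120M * 768 * 4 ~ 370 GB for MAG240M), whereas streaming BLOCK_COLS = 16
# columns keeps the in-memory slice around 120M * 16 * 4 ~ 7.7 GB per step.
# (These are rough, hedged estimates; exact sizes depend on the dataset release.)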
......@@ -73,34 +128,56 @@ if args.graph_as_homogeneous:
# DGL also ensures that the node types are sorted in ascending order.
assert torch.equal(
g.ndata[dgl.NTYPE],
torch.cat([torch.full((dataset.num_authors,), 0),
torch.full((dataset.num_institutions,), 1),
torch.full((dataset.num_papers,), 2)]))
torch.cat(
[
torch.full((dataset.num_authors,), 0),
torch.full((dataset.num_institutions,), 1),
torch.full((dataset.num_papers,), 2),
]
),
)
assert torch.equal(
g.ndata[dgl.NID],
torch.cat([torch.arange(dataset.num_authors),
torch.arange(dataset.num_institutions),
torch.arange(dataset.num_papers)]))
g.edata['etype'] = g.edata[dgl.ETYPE].byte()
torch.cat(
[
torch.arange(dataset.num_authors),
torch.arange(dataset.num_institutions),
torch.arange(dataset.num_papers),
]
),
)
g.edata["etype"] = g.edata[dgl.ETYPE].byte()
del g.edata[dgl.ETYPE]
del g.ndata[dgl.NTYPE]
del g.ndata[dgl.NID]
# Assemble the features of all node types into a single matrix
full_feat = np.memmap(
args.full_output_path, mode='w+', dtype='float16',
shape=(dataset.num_authors + dataset.num_institutions + dataset.num_papers, dataset.num_paper_features))
args.full_output_path,
mode="w+",
dtype="float16",
shape=(
dataset.num_authors + dataset.num_institutions + dataset.num_papers,
dataset.num_paper_features,
),
)
BLOCK_ROWS = 100000
for start in tqdm.trange(0, dataset.num_authors, BLOCK_ROWS):
end = min(dataset.num_authors, start + BLOCK_ROWS)
full_feat[author_offset + start:author_offset + end] = author_feat[start:end]
full_feat[author_offset + start : author_offset + end] = author_feat[
start:end
]
for start in tqdm.trange(0, dataset.num_institutions, BLOCK_ROWS):
end = min(dataset.num_institutions, start + BLOCK_ROWS)
full_feat[inst_offset + start:inst_offset + end] = inst_feat[start:end]
full_feat[inst_offset + start : inst_offset + end] = inst_feat[
start:end
]
for start in tqdm.trange(0, dataset.num_papers, BLOCK_ROWS):
end = min(dataset.num_papers, start + BLOCK_ROWS)
full_feat[paper_offset + start:paper_offset + end] = paper_feat[start:end]
full_feat[paper_offset + start : paper_offset + end] = paper_feat[
start:end
]
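# The row order of full_feat (authors, then institutions, then papers) matches
# the homogeneous node ID layout asserted above, so the training scripts can
# index this memmap directly with global node IDs (e.g. feats[input_nodes]).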
# Convert the graph to the given format and save. (The RGAT baseline needs a CSC graph.)
g = g.formats(args.graph_format)
dgl.save_graphs(args.graph_output_path, g)
#!/usr/bin/env python
# coding: utf-8
import argparse
import time
import numpy as np
import ogb
from ogb.lsc import MAG240MDataset, MAG240MEvaluator
import dgl
import torch
import numpy as np
import time
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from ogb.lsc import MAG240MDataset, MAG240MEvaluator
import dgl
import dgl.function as fn
import numpy as np
import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F
import argparse
class RGAT(nn.Module):
def __init__(self, in_channels, out_channels, hidden_channels, num_etypes, num_layers, num_heads, dropout, pred_ntype):
def __init__(
self,
in_channels,
out_channels,
hidden_channels,
num_etypes,
num_layers,
num_heads,
dropout,
pred_ntype,
):
super().__init__()
self.convs = nn.ModuleList()
self.norms = nn.ModuleList()
self.skips = nn.ModuleList()
self.convs.append(nn.ModuleList([
dglnn.GATConv(in_channels, hidden_channels // num_heads, num_heads, allow_zero_in_degree=True)
for _ in range(num_etypes)
]))
self.convs.append(
nn.ModuleList(
[
dglnn.GATConv(
in_channels,
hidden_channels // num_heads,
num_heads,
allow_zero_in_degree=True,
)
for _ in range(num_etypes)
]
)
)
self.norms.append(nn.BatchNorm1d(hidden_channels))
self.skips.append(nn.Linear(in_channels, hidden_channels))
for _ in range(num_layers - 1):
self.convs.append(nn.ModuleList([
dglnn.GATConv(hidden_channels, hidden_channels // num_heads, num_heads, allow_zero_in_degree=True)
for _ in range(num_etypes)
]))
self.convs.append(
nn.ModuleList(
[
dglnn.GATConv(
hidden_channels,
hidden_channels // num_heads,
num_heads,
allow_zero_in_degree=True,
)
for _ in range(num_etypes)
]
)
)
self.norms.append(nn.BatchNorm1d(hidden_channels))
self.skips.append(nn.Linear(hidden_channels, hidden_channels))
self.mlp = nn.Sequential(
nn.Linear(hidden_channels, hidden_channels),
nn.BatchNorm1d(hidden_channels),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_channels, out_channels)
nn.Linear(hidden_channels, out_channels),
)
self.dropout = nn.Dropout(dropout)
self.hidden_channels = hidden_channels
self.pred_ntype = pred_ntype
self.num_etypes = num_etypes
def forward(self, mfgs, x):
for i in range(len(mfgs)):
mfg = mfgs[i]
x_dst = x[:mfg.num_dst_nodes()]
x_dst = x[: mfg.num_dst_nodes()]
n_src = mfg.num_src_nodes()
n_dst = mfg.num_dst_nodes()
mfg = dgl.block_to_graph(mfg)
x_skip = self.skips[i](x_dst)
for j in range(self.num_etypes):
subg = mfg.edge_subgraph(mfg.edata['etype'] == j, relabel_nodes=False)
x_skip += self.convs[i][j](subg, (x, x_dst)).view(-1, self.hidden_channels)
subg = mfg.edge_subgraph(
mfg.edata["etype"] == j, relabel_nodes=False
)
x_skip += self.convs[i][j](subg, (x, x_dst)).view(
-1, self.hidden_channels
)
x = self.norms[i](x_skip)
x = F.elu(x)
x = self.dropout(x)
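# In effect, each layer runs one homogeneous GATConv per relation on that
# relation's edge subgraph (relabel_nodes=False keeps the block's node
# numbering, so the (x, x_dst) feature pair still lines up) and sums the
# per-relation outputs on top of a linear skip connection: a simple
# relation-wise GAT.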
......@@ -76,27 +110,34 @@ class ExternalNodeCollator(dgl.dataloading.NodeCollator):
def collate(self, items):
input_nodes, output_nodes, mfgs = super().collate(items)
# Copy input features
mfgs[0].srcdata['x'] = torch.FloatTensor(self.feats[input_nodes])
mfgs[-1].dstdata['y'] = torch.LongTensor(self.label[output_nodes - self.offset])
mfgs[0].srcdata["x"] = torch.FloatTensor(self.feats[input_nodes])
mfgs[-1].dstdata["y"] = torch.LongTensor(
self.label[output_nodes - self.offset]
)
return input_nodes, output_nodes, mfgs
def train(args, dataset, g, feats, paper_offset):
print('Loading masks and labels')
train_idx = torch.LongTensor(dataset.get_idx_split('train')) + paper_offset
valid_idx = torch.LongTensor(dataset.get_idx_split('valid')) + paper_offset
print("Loading masks and labels")
train_idx = torch.LongTensor(dataset.get_idx_split("train")) + paper_offset
valid_idx = torch.LongTensor(dataset.get_idx_split("valid")) + paper_offset
label = dataset.paper_label
print('Initializing dataloader...')
print("Initializing dataloader...")
sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 25])
train_collator = ExternalNodeCollator(g, train_idx, sampler, paper_offset, feats, label)
valid_collator = ExternalNodeCollator(g, valid_idx, sampler, paper_offset, feats, label)
train_collator = ExternalNodeCollator(
g, train_idx, sampler, paper_offset, feats, label
)
valid_collator = ExternalNodeCollator(
g, valid_idx, sampler, paper_offset, feats, label
)
train_dataloader = torch.utils.data.DataLoader(
train_collator.dataset,
batch_size=1024,
shuffle=True,
drop_last=False,
collate_fn=train_collator.collate,
num_workers=4
num_workers=4,
)
valid_dataloader = torch.utils.data.DataLoader(
valid_collator.dataset,
......@@ -104,11 +145,20 @@ def train(args, dataset, g, feats, paper_offset):
shuffle=True,
drop_last=False,
collate_fn=valid_collator.collate,
num_workers=2
num_workers=2,
)
print('Initializing model...')
model = RGAT(dataset.num_paper_features, dataset.num_classes, 1024, 5, 2, 4, 0.5, 'paper').cuda()
print("Initializing model...")
model = RGAT(
dataset.num_paper_features,
dataset.num_classes,
1024,
5,
2,
4,
0.5,
"paper",
).cuda()
opt = torch.optim.Adam(model.parameters(), lr=0.001)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=25, gamma=0.25)
......@@ -118,114 +168,170 @@ def train(args, dataset, g, feats, paper_offset):
model.train()
with tqdm.tqdm(train_dataloader) as tq:
for i, (input_nodes, output_nodes, mfgs) in enumerate(tq):
mfgs = [g.to('cuda') for g in mfgs]
x = mfgs[0].srcdata['x']
y = mfgs[-1].dstdata['y']
mfgs = [g.to("cuda") for g in mfgs]
x = mfgs[0].srcdata["x"]
y = mfgs[-1].dstdata["y"]
y_hat = model(mfgs, x)
loss = F.cross_entropy(y_hat, y)
opt.zero_grad()
loss.backward()
opt.step()
acc = (y_hat.argmax(1) == y).float().mean()
tq.set_postfix({'loss': '%.4f' % loss.item(), 'acc': '%.4f' % acc.item()}, refresh=False)
tq.set_postfix(
{"loss": "%.4f" % loss.item(), "acc": "%.4f" % acc.item()},
refresh=False,
)
model.eval()
correct = total = 0
for i, (input_nodes, output_nodes, mfgs) in enumerate(tqdm.tqdm(valid_dataloader)):
for i, (input_nodes, output_nodes, mfgs) in enumerate(
tqdm.tqdm(valid_dataloader)
):
with torch.no_grad():
mfgs = [g.to('cuda') for g in mfgs]
x = mfgs[0].srcdata['x']
y = mfgs[-1].dstdata['y']
mfgs = [g.to("cuda") for g in mfgs]
x = mfgs[0].srcdata["x"]
y = mfgs[-1].dstdata["y"]
y_hat = model(mfgs, x)
correct += (y_hat.argmax(1) == y).sum().item()
total += y_hat.shape[0]
acc = correct / total
print('Validation accuracy:', acc)
print("Validation accuracy:", acc)
sched.step()
if best_acc < acc:
best_acc = acc
print('Updating best model...')
print("Updating best model...")
torch.save(model.state_dict(), args.model_path)
def test(args, dataset, g, feats, paper_offset):
print('Loading masks and labels...')
valid_idx = torch.LongTensor(dataset.get_idx_split('valid')) + paper_offset
test_idx = torch.LongTensor(dataset.get_idx_split('test')) + paper_offset
print("Loading masks and labels...")
valid_idx = torch.LongTensor(dataset.get_idx_split("valid")) + paper_offset
test_idx = torch.LongTensor(dataset.get_idx_split("test")) + paper_offset
label = dataset.paper_label
print('Initializing data loader...')
print("Initializing data loader...")
sampler = dgl.dataloading.MultiLayerNeighborSampler([160, 160])
valid_collator = ExternalNodeCollator(g, valid_idx, sampler, paper_offset, feats, label)
valid_collator = ExternalNodeCollator(
g, valid_idx, sampler, paper_offset, feats, label
)
valid_dataloader = torch.utils.data.DataLoader(
valid_collator.dataset,
batch_size=16,
shuffle=False,
drop_last=False,
collate_fn=valid_collator.collate,
num_workers=2
num_workers=2,
)
test_collator = ExternalNodeCollator(
g, test_idx, sampler, paper_offset, feats, label
)
test_collator = ExternalNodeCollator(g, test_idx, sampler, paper_offset, feats, label)
test_dataloader = torch.utils.data.DataLoader(
test_collator.dataset,
batch_size=16,
shuffle=False,
drop_last=False,
collate_fn=test_collator.collate,
num_workers=4
num_workers=4,
)
print('Loading model...')
model = RGAT(dataset.num_paper_features, dataset.num_classes, 1024, 5, 2, 4, 0.5, 'paper').cuda()
print("Loading model...")
model = RGAT(
dataset.num_paper_features,
dataset.num_classes,
1024,
5,
2,
4,
0.5,
"paper",
).cuda()
model.load_state_dict(torch.load(args.model_path))
model.eval()
correct = total = 0
for i, (input_nodes, output_nodes, mfgs) in enumerate(tqdm.tqdm(valid_dataloader)):
for i, (input_nodes, output_nodes, mfgs) in enumerate(
tqdm.tqdm(valid_dataloader)
):
with torch.no_grad():
mfgs = [g.to('cuda') for g in mfgs]
x = mfgs[0].srcdata['x']
y = mfgs[-1].dstdata['y']
mfgs = [g.to("cuda") for g in mfgs]
x = mfgs[0].srcdata["x"]
y = mfgs[-1].dstdata["y"]
y_hat = model(mfgs, x)
correct += (y_hat.argmax(1) == y).sum().item()
total += y_hat.shape[0]
acc = correct / total
print('Validation accuracy:', acc)
print("Validation accuracy:", acc)
evaluator = MAG240MEvaluator()
y_preds = []
for i, (input_nodes, output_nodes, mfgs) in enumerate(tqdm.tqdm(test_dataloader)):
for i, (input_nodes, output_nodes, mfgs) in enumerate(
tqdm.tqdm(test_dataloader)
):
with torch.no_grad():
mfgs = [g.to('cuda') for g in mfgs]
x = mfgs[0].srcdata['x']
y = mfgs[-1].dstdata['y']
mfgs = [g.to("cuda") for g in mfgs]
x = mfgs[0].srcdata["x"]
y = mfgs[-1].dstdata["y"]
y_hat = model(mfgs, x)
y_preds.append(y_hat.argmax(1).cpu())
evaluator.save_test_submission({'y_pred': torch.cat(y_preds)}, args.submission_path)
evaluator.save_test_submission(
{"y_pred": torch.cat(y_preds)}, args.submission_path
)
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--rootdir', type=str, default='.', help='Directory to download the OGB dataset.')
parser.add_argument('--graph-path', type=str, default='./graph.dgl', help='Path to the graph.')
parser.add_argument('--full-feature-path', type=str, default='./full.npy',
help='Path to the features of all nodes.')
parser.add_argument('--epochs', type=int, default=100, help='Number of epochs.')
parser.add_argument('--model-path', type=str, default='./model.pt', help='Path to store the best model.')
parser.add_argument('--submission-path', type=str, default='./results', help='Submission directory.')
parser.add_argument(
"--rootdir",
type=str,
default=".",
help="Directory to download the OGB dataset.",
)
parser.add_argument(
"--graph-path",
type=str,
default="./graph.dgl",
help="Path to the graph.",
)
parser.add_argument(
"--full-feature-path",
type=str,
default="./full.npy",
help="Path to the features of all nodes.",
)
parser.add_argument(
"--epochs", type=int, default=100, help="Number of epochs."
)
parser.add_argument(
"--model-path",
type=str,
default="./model.pt",
help="Path to store the best model.",
)
parser.add_argument(
"--submission-path",
type=str,
default="./results",
help="Submission directory.",
)
args = parser.parse_args()
dataset = MAG240MDataset(root=args.rootdir)
print('Loading graph')
print("Loading graph")
(g,), _ = dgl.load_graphs(args.graph_path)
g = g.formats(['csc'])
g = g.formats(["csc"])
print('Loading features')
print("Loading features")
paper_offset = dataset.num_authors + dataset.num_institutions
num_nodes = paper_offset + dataset.num_papers
num_features = dataset.num_paper_features
feats = np.memmap(args.full_feature_path, mode='r', dtype='float16', shape=(num_nodes, num_features))
feats = np.memmap(
args.full_feature_path,
mode="r",
dtype="float16",
shape=(num_nodes, num_features),
)
if args.epochs != 0:
train(args, dataset, g, feats, paper_offset)
......
#!/usr/bin/env python
# coding: utf-8
import argparse
import math
import sys
from collections import OrderedDict
from ogb.lsc import MAG240MDataset, MAG240MEvaluator
import dgl
import torch
import tqdm
import numpy as np
import dgl.nn as dglnn
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import argparse
import torch.multiprocessing as mp
import sys
import tqdm
from ogb.lsc import MAG240MDataset, MAG240MEvaluator
from torch.nn.parallel import DistributedDataParallel
from collections import OrderedDict
import dgl
import dgl.nn as dglnn
class RGAT(nn.Module):
def __init__(self, in_channels, out_channels, hidden_channels, num_etypes, num_layers, num_heads, dropout,
pred_ntype):
def __init__(
self,
in_channels,
out_channels,
hidden_channels,
num_etypes,
num_layers,
num_heads,
dropout,
pred_ntype,
):
super().__init__()
self.convs = nn.ModuleList()
self.norms = nn.ModuleList()
self.skips = nn.ModuleList()
self.convs.append(nn.ModuleList([
dglnn.GATConv(in_channels, hidden_channels // num_heads, num_heads, allow_zero_in_degree=True)
for _ in range(num_etypes)
]))
self.convs.append(
nn.ModuleList(
[
dglnn.GATConv(
in_channels,
hidden_channels // num_heads,
num_heads,
allow_zero_in_degree=True,
)
for _ in range(num_etypes)
]
)
)
self.norms.append(nn.BatchNorm1d(hidden_channels))
self.skips.append(nn.Linear(in_channels, hidden_channels))
for _ in range(num_layers - 1):
self.convs.append(nn.ModuleList([
dglnn.GATConv(hidden_channels, hidden_channels // num_heads, num_heads, allow_zero_in_degree=True)
for _ in range(num_etypes)
]))
self.convs.append(
nn.ModuleList(
[
dglnn.GATConv(
hidden_channels,
hidden_channels // num_heads,
num_heads,
allow_zero_in_degree=True,
)
for _ in range(num_etypes)
]
)
)
self.norms.append(nn.BatchNorm1d(hidden_channels))
self.skips.append(nn.Linear(hidden_channels, hidden_channels))
......@@ -44,7 +72,7 @@ class RGAT(nn.Module):
nn.BatchNorm1d(hidden_channels),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_channels, out_channels)
nn.Linear(hidden_channels, out_channels),
)
self.dropout = nn.Dropout(dropout)
......@@ -55,14 +83,18 @@ class RGAT(nn.Module):
def forward(self, mfgs, x):
for i in range(len(mfgs)):
mfg = mfgs[i]
x_dst = x[:mfg.num_dst_nodes()]
x_dst = x[: mfg.num_dst_nodes()]
n_src = mfg.num_src_nodes()
n_dst = mfg.num_dst_nodes()
mfg = dgl.block_to_graph(mfg)
x_skip = self.skips[i](x_dst)
for j in range(self.num_etypes):
subg = mfg.edge_subgraph(mfg.edata['etype'] == j, relabel_nodes=False)
x_skip += self.convs[i][j](subg, (x, x_dst)).view(-1, self.hidden_channels)
subg = mfg.edge_subgraph(
mfg.edata["etype"] == j, relabel_nodes=False
)
x_skip += self.convs[i][j](subg, (x, x_dst)).view(
-1, self.hidden_channels
)
x = self.norms[i](x_skip)
x = F.elu(x)
x = self.dropout(x)
......@@ -79,46 +111,65 @@ class ExternalNodeCollator(dgl.dataloading.NodeCollator):
def collate(self, items):
input_nodes, output_nodes, mfgs = super().collate(items)
# Copy input features
mfgs[0].srcdata['x'] = torch.FloatTensor(self.feats[input_nodes])
mfgs[-1].dstdata['y'] = torch.LongTensor(self.label[output_nodes - self.offset])
mfgs[0].srcdata["x"] = torch.FloatTensor(self.feats[input_nodes])
mfgs[-1].dstdata["y"] = torch.LongTensor(
self.label[output_nodes - self.offset]
)
return input_nodes, output_nodes, mfgs
def train(proc_id, n_gpus, args, dataset, g, feats, paper_offset):
dev_id = devices[proc_id]
if n_gpus > 1:
dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
master_ip='127.0.0.1', master_port='12346')
dist_init_method = "tcp://{master_ip}:{master_port}".format(
master_ip="127.0.0.1", master_port="12346"
)
world_size = n_gpus
torch.distributed.init_process_group(backend='nccl',
init_method=dist_init_method,
world_size=world_size,
rank=proc_id)
torch.distributed.init_process_group(
backend="nccl",
init_method=dist_init_method,
world_size=world_size,
rank=proc_id,
)
torch.cuda.set_device(dev_id)
print('Loading masks and labels')
train_idx = torch.LongTensor(dataset.get_idx_split('train')) + paper_offset
valid_idx = torch.LongTensor(dataset.get_idx_split('valid')) + paper_offset
print("Loading masks and labels")
train_idx = torch.LongTensor(dataset.get_idx_split("train")) + paper_offset
valid_idx = torch.LongTensor(dataset.get_idx_split("valid")) + paper_offset
label = dataset.paper_label
print('Initializing dataloader...')
print("Initializing dataloader...")
sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 25])
train_collator = ExternalNodeCollator(g, train_idx, sampler, paper_offset, feats, label)
valid_collator = ExternalNodeCollator(g, valid_idx, sampler, paper_offset, feats, label)
train_collator = ExternalNodeCollator(
g, train_idx, sampler, paper_offset, feats, label
)
valid_collator = ExternalNodeCollator(
g, valid_idx, sampler, paper_offset, feats, label
)
# A DistributedSampler per process is necessary so that each rank sees a distinct shard of the training set; see https://yangkky.github.io/2019/07/08/distributed-pytorch-tutorial.html
train_sampler = torch.utils.data.distributed.DistributedSampler(
train_collator.dataset, num_replicas=world_size, rank=proc_id, shuffle=True, drop_last=False)
train_collator.dataset,
num_replicas=world_size,
rank=proc_id,
shuffle=True,
drop_last=False,
)
valid_sampler = torch.utils.data.distributed.DistributedSampler(
valid_collator.dataset, num_replicas=world_size, rank=proc_id, shuffle=True, drop_last=False)
valid_collator.dataset,
num_replicas=world_size,
rank=proc_id,
shuffle=True,
drop_last=False,
)
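# A standard companion to DistributedSampler (easy to miss) is calling
# train_sampler.set_epoch(epoch) at the top of every epoch, which reseeds the
# per-rank shuffle so each epoch sees a different ordering, e.g.
#     for epoch in range(args.epochs):
#         train_sampler.set_epoch(epoch)
#         ...
# Shown here only as a hedged reminder of the PyTorch API.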
train_dataloader = torch.utils.data.DataLoader(
train_collator.dataset,
batch_size=1024,
collate_fn=train_collator.collate,
num_workers=4,
sampler=train_sampler
sampler=train_sampler,
)
valid_dataloader = torch.utils.data.DataLoader(
......@@ -126,16 +177,27 @@ def train(proc_id, n_gpus, args, dataset, g, feats, paper_offset):
batch_size=1024,
collate_fn=valid_collator.collate,
num_workers=2,
sampler=valid_sampler
sampler=valid_sampler,
)
print('Initializing model...')
model = RGAT(dataset.num_paper_features, dataset.num_classes, 1024, 5, 2, 4, 0.5, 'paper').to(dev_id)
print("Initializing model...")
model = RGAT(
dataset.num_paper_features,
dataset.num_classes,
1024,
5,
2,
4,
0.5,
"paper",
).to(dev_id)
# convert BN to SyncBatchNorm. see https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = DistributedDataParallel(model, device_ids=[dev_id], output_device=dev_id)
model = DistributedDataParallel(
model, device_ids=[dev_id], output_device=dev_id
)
opt = torch.optim.Adam(model.parameters(), lr=0.001)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=25, gamma=0.25)
......@@ -149,75 +211,97 @@ def train(proc_id, n_gpus, args, dataset, g, feats, paper_offset):
with tqdm.tqdm(train_dataloader) as tq:
for i, (input_nodes, output_nodes, mfgs) in enumerate(tq):
mfgs = [g.to(dev_id) for g in mfgs]
x = mfgs[0].srcdata['x']
y = mfgs[-1].dstdata['y']
x = mfgs[0].srcdata["x"]
y = mfgs[-1].dstdata["y"]
y_hat = model(mfgs, x)
loss = F.cross_entropy(y_hat, y)
opt.zero_grad()
loss.backward()
opt.step()
acc = (y_hat.argmax(1) == y).float().mean()
tq.set_postfix({'loss': '%.4f' % loss.item(), 'acc': '%.4f' % acc.item()}, refresh=False)
tq.set_postfix(
{"loss": "%.4f" % loss.item(), "acc": "%.4f" % acc.item()},
refresh=False,
)
# eval in each process
model.eval()
correct = torch.LongTensor([0]).to(dev_id)
total = torch.LongTensor([0]).to(dev_id)
for i, (input_nodes, output_nodes, mfgs) in enumerate(tqdm.tqdm(valid_dataloader)):
for i, (input_nodes, output_nodes, mfgs) in enumerate(
tqdm.tqdm(valid_dataloader)
):
with torch.no_grad():
mfgs = [g.to(dev_id) for g in mfgs]
x = mfgs[0].srcdata['x']
y = mfgs[-1].dstdata['y']
x = mfgs[0].srcdata["x"]
y = mfgs[-1].dstdata["y"]
y_hat = model(mfgs, x)
correct += (y_hat.argmax(1) == y).sum().item()
total += y_hat.shape[0]
# `reduce` data into process 0
torch.distributed.reduce(correct, dst=0, op=torch.distributed.ReduceOp.SUM)
torch.distributed.reduce(total, dst=0, op=torch.distributed.ReduceOp.SUM)
torch.distributed.reduce(
correct, dst=0, op=torch.distributed.ReduceOp.SUM
)
torch.distributed.reduce(
total, dst=0, op=torch.distributed.ReduceOp.SUM
)
acc = (correct / total).item()
sched.step()
# process 0 prints the accuracy and saves the model
if proc_id == 0:
print('Validation accuracy:', acc)
print("Validation accuracy:", acc)
if best_acc < acc:
best_acc = acc
print('Updating best model...')
print("Updating best model...")
torch.save(model.state_dict(), args.model_path)
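# Note: `model` is the DistributedDataParallel wrapper at this point, so the
# saved state dict keys carry a "module." prefix; the test path below has to
# strip it before loading into a plain RGAT.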
def test(args, dataset, g, feats, paper_offset):
print('Loading masks and labels...')
valid_idx = torch.LongTensor(dataset.get_idx_split('valid')) + paper_offset
test_idx = torch.LongTensor(dataset.get_idx_split('test')) + paper_offset
print("Loading masks and labels...")
valid_idx = torch.LongTensor(dataset.get_idx_split("valid")) + paper_offset
test_idx = torch.LongTensor(dataset.get_idx_split("test")) + paper_offset
label = dataset.paper_label
print('Initializing data loader...')
print("Initializing data loader...")
sampler = dgl.dataloading.MultiLayerNeighborSampler([160, 160])
valid_collator = ExternalNodeCollator(g, valid_idx, sampler, paper_offset, feats, label)
valid_collator = ExternalNodeCollator(
g, valid_idx, sampler, paper_offset, feats, label
)
valid_dataloader = torch.utils.data.DataLoader(
valid_collator.dataset,
batch_size=16,
shuffle=False,
drop_last=False,
collate_fn=valid_collator.collate,
num_workers=2
num_workers=2,
)
test_collator = ExternalNodeCollator(
g, test_idx, sampler, paper_offset, feats, label
)
test_collator = ExternalNodeCollator(g, test_idx, sampler, paper_offset, feats, label)
test_dataloader = torch.utils.data.DataLoader(
test_collator.dataset,
batch_size=16,
shuffle=False,
drop_last=False,
collate_fn=test_collator.collate,
num_workers=4
num_workers=4,
)
print('Loading model...')
model = RGAT(dataset.num_paper_features, dataset.num_classes, 1024, 5, 2, 4, 0.5, 'paper').cuda()
print("Loading model...")
model = RGAT(
dataset.num_paper_features,
dataset.num_classes,
1024,
5,
2,
4,
0.5,
"paper",
).cuda()
# Load parameters saved from the DDP wrapper; the keys carry a 'module.' prefix that has to be removed.
state_dict = torch.load(args.model_path)
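# Because the checkpoint was saved from a DistributedDataParallel wrapper, its
# keys carry a "module." prefix.  A minimal sketch of stripping it (a common
# pattern, not necessarily the exact original code):
#     new_state_dict = OrderedDict(
#         (k[len("module."):] if k.startswith("module.") else k, v)
#         for k, v in state_dict.items()
#     )
#     model.load_state_dict(new_state_dict)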
......@@ -229,41 +313,73 @@ def test(args, dataset, g, feats, paper_offset):
model.eval()
correct = total = 0
for i, (input_nodes, output_nodes, mfgs) in enumerate(tqdm.tqdm(valid_dataloader)):
for i, (input_nodes, output_nodes, mfgs) in enumerate(
tqdm.tqdm(valid_dataloader)
):
with torch.no_grad():
mfgs = [g.to('cuda') for g in mfgs]
x = mfgs[0].srcdata['x']
y = mfgs[-1].dstdata['y']
mfgs = [g.to("cuda") for g in mfgs]
x = mfgs[0].srcdata["x"]
y = mfgs[-1].dstdata["y"]
y_hat = model(mfgs, x)
correct += (y_hat.argmax(1) == y).sum().item()
total += y_hat.shape[0]
acc = correct / total
print('Validation accuracy:', acc)
print("Validation accuracy:", acc)
evaluator = MAG240MEvaluator()
y_preds = []
for i, (input_nodes, output_nodes, mfgs) in enumerate(tqdm.tqdm(test_dataloader)):
for i, (input_nodes, output_nodes, mfgs) in enumerate(
tqdm.tqdm(test_dataloader)
):
with torch.no_grad():
mfgs = [g.to('cuda') for g in mfgs]
x = mfgs[0].srcdata['x']
y = mfgs[-1].dstdata['y']
mfgs = [g.to("cuda") for g in mfgs]
x = mfgs[0].srcdata["x"]
y = mfgs[-1].dstdata["y"]
y_hat = model(mfgs, x)
y_preds.append(y_hat.argmax(1).cpu())
evaluator.save_test_submission({'y_pred': torch.cat(y_preds)}, args.submission_path)
evaluator.save_test_submission(
{"y_pred": torch.cat(y_preds)}, args.submission_path
)
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--rootdir', type=str, default='.', help='Directory to download the OGB dataset.')
parser.add_argument('--graph-path', type=str, default='./graph.dgl', help='Path to the graph.')
parser.add_argument('--full-feature-path', type=str, default='./full.npy',
help='Path to the features of all nodes.')
parser.add_argument('--epochs', type=int, default=100, help='Number of epochs.')
parser.add_argument('--model-path', type=str, default='./model_ddp.pt', help='Path to store the best model.')
parser.add_argument('--submission-path', type=str, default='./results_ddp', help='Submission directory.')
parser.add_argument('--gpus', type=str, default='0,1,2')
parser.add_argument(
"--rootdir",
type=str,
default=".",
help="Directory to download the OGB dataset.",
)
parser.add_argument(
"--graph-path",
type=str,
default="./graph.dgl",
help="Path to the graph.",
)
parser.add_argument(
"--full-feature-path",
type=str,
default="./full.npy",
help="Path to the features of all nodes.",
)
parser.add_argument(
"--epochs", type=int, default=100, help="Number of epochs."
)
parser.add_argument(
"--model-path",
type=str,
default="./model_ddp.pt",
help="Path to store the best model.",
)
parser.add_argument(
"--submission-path",
type=str,
default="./results_ddp",
help="Submission directory.",
)
parser.add_argument("--gpus", type=str, default="0,1,2")
args = parser.parse_args()
devices = list(map(int, args.gpus.split(',')))
devices = list(map(int, args.gpus.split(",")))
n_gpus = len(devices)
if n_gpus <= 1:
......@@ -272,16 +388,25 @@ if __name__ == '__main__':
dataset = MAG240MDataset(root=args.rootdir)
print('Loading graph')
print("Loading graph")
(g,), _ = dgl.load_graphs(args.graph_path)
g = g.formats(['csc'])
g = g.formats(["csc"])
print('Loading features')
print("Loading features")
paper_offset = dataset.num_authors + dataset.num_institutions
num_nodes = paper_offset + dataset.num_papers
num_features = dataset.num_paper_features
feats = np.memmap(args.full_feature_path, mode='r', dtype='float16', shape=(num_nodes, num_features))
feats = np.memmap(
args.full_feature_path,
mode="r",
dtype="float16",
shape=(num_nodes, num_features),
)
mp.spawn(train, args=(n_gpus, args, dataset, g, feats, paper_offset), nprocs=n_gpus)
mp.spawn(
train,
args=(n_gpus, args, dataset, g, feats, paper_offset),
nprocs=n_gpus,
)
test(args, dataset, g, feats, paper_offset)
import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
import dgl
import dgl.function as fn
from dgl.nn.pytorch import SumPooling
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
### GIN convolution along the graph structure
class GINConv(nn.Module):
def __init__(self, emb_dim):
'''
emb_dim (int): node embedding dimensionality
'''
"""
emb_dim (int): node embedding dimensionality
"""
super(GINConv, self).__init__()
self.mlp = nn.Sequential(nn.Linear(emb_dim, emb_dim),
nn.BatchNorm1d(emb_dim),
nn.ReLU(),
nn.Linear(emb_dim, emb_dim))
self.mlp = nn.Sequential(
nn.Linear(emb_dim, emb_dim),
nn.BatchNorm1d(emb_dim),
nn.ReLU(),
nn.Linear(emb_dim, emb_dim),
)
self.eps = nn.Parameter(torch.Tensor([0]))
self.bond_encoder = BondEncoder(emb_dim = emb_dim)
self.bond_encoder = BondEncoder(emb_dim=emb_dim)
def forward(self, g, x, edge_attr):
with g.local_scope():
edge_embedding = self.bond_encoder(edge_attr)
g.ndata['x'] = x
g.apply_edges(fn.copy_u('x', 'm'))
g.edata['m'] = F.relu(g.edata['m'] + edge_embedding)
g.update_all(fn.copy_e('m', 'm'), fn.sum('m', 'new_x'))
out = self.mlp((1 + self.eps) * x + g.ndata['new_x'])
g.ndata["x"] = x
g.apply_edges(fn.copy_u("x", "m"))
g.edata["m"] = F.relu(g.edata["m"] + edge_embedding)
g.update_all(fn.copy_e("m", "m"), fn.sum("m", "new_x"))
out = self.mlp((1 + self.eps) * x + g.ndata["new_x"])
return out
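# In equation form, the layer above computes
#     h_v' = MLP((1 + eps) * h_v + sum_{u in N(v)} ReLU(h_u + e_uv))
# where e_uv is the encoded bond feature and eps is a learnable scalar, i.e.
# GIN extended with edge features as in the OGB molecular baselines.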
### GCN convolution along the graph structure
class GCNConv(nn.Module):
def __init__(self, emb_dim):
'''
emb_dim (int): node embedding dimensionality
'''
"""
emb_dim (int): node embedding dimensionality
"""
super(GCNConv, self).__init__()
self.linear = nn.Linear(emb_dim, emb_dim)
self.root_emb = nn.Embedding(1, emb_dim)
self.bond_encoder = BondEncoder(emb_dim = emb_dim)
self.bond_encoder = BondEncoder(emb_dim=emb_dim)
def forward(self, g, x, edge_attr):
with g.local_scope():
......@@ -56,29 +60,43 @@ class GCNConv(nn.Module):
# Molecular graphs are undirected
# g.out_degrees() is the same as g.in_degrees()
degs = (g.out_degrees().float() + 1).to(x.device)
norm = torch.pow(degs, -0.5).unsqueeze(-1) # (N, 1)
g.ndata['norm'] = norm
g.apply_edges(fn.u_mul_v('norm', 'norm', 'norm'))
g.ndata['x'] = x
g.apply_edges(fn.copy_u('x', 'm'))
g.edata['m'] = g.edata['norm'] * F.relu(g.edata['m'] + edge_embedding)
g.update_all(fn.copy_e('m', 'm'), fn.sum('m', 'new_x'))
out = g.ndata['new_x'] + F.relu(x + self.root_emb.weight) * 1. / degs.view(-1, 1)
norm = torch.pow(degs, -0.5).unsqueeze(-1) # (N, 1)
g.ndata["norm"] = norm
g.apply_edges(fn.u_mul_v("norm", "norm", "norm"))
g.ndata["x"] = x
g.apply_edges(fn.copy_u("x", "m"))
g.edata["m"] = g.edata["norm"] * F.relu(
g.edata["m"] + edge_embedding
)
g.update_all(fn.copy_e("m", "m"), fn.sum("m", "new_x"))
out = g.ndata["new_x"] + F.relu(
x + self.root_emb.weight
) * 1.0 / degs.view(-1, 1)
return out
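# The corresponding update here is degree-normalized:
#     h_v' = sum_{u in N(v)} (d_u * d_v)^(-1/2) * ReLU(h_u + e_uv)
#            + ReLU(h_v + b_root) / d_v
# with d = out-degree + 1, i.e. the GCN-with-edge-features variant of the same
# baselines.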
### GNN to generate node embedding
class GNN_node(nn.Module):
"""
Output:
node representations
"""
def __init__(self, num_layers, emb_dim, drop_ratio = 0.5, JK = "last", residual = False, gnn_type = 'gin'):
'''
num_layers (int): number of GNN message passing layers
emb_dim (int): node embedding dimensionality
'''
def __init__(
self,
num_layers,
emb_dim,
drop_ratio=0.5,
JK="last",
residual=False,
gnn_type="gin",
):
"""
num_layers (int): number of GNN message passing layers
emb_dim (int): node embedding dimensionality
"""
super(GNN_node, self).__init__()
self.num_layers = num_layers
......@@ -97,12 +115,12 @@ class GNN_node(nn.Module):
self.batch_norms = nn.ModuleList()
for layer in range(num_layers):
if gnn_type == 'gin':
if gnn_type == "gin":
self.convs.append(GINConv(emb_dim))
elif gnn_type == 'gcn':
elif gnn_type == "gcn":
self.convs.append(GCNConv(emb_dim))
else:
raise ValueError('Undefined GNN type called {}'.format(gnn_type))
raise ValueError("Undefined GNN type called {}".format(gnn_type))
self.batch_norms.append(nn.BatchNorm1d(emb_dim))
......@@ -115,10 +133,12 @@ class GNN_node(nn.Module):
h = self.batch_norms[layer](h)
if layer == self.num_layers - 1:
#remove relu for the last layer
h = F.dropout(h, self.drop_ratio, training = self.training)
# remove relu for the last layer
h = F.dropout(h, self.drop_ratio, training=self.training)
else:
h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)
h = F.dropout(
F.relu(h), self.drop_ratio, training=self.training
)
if self.residual:
h += h_list[layer]
......@@ -142,11 +162,20 @@ class GNN_node_Virtualnode(nn.Module):
Output:
node representations
"""
def __init__(self, num_layers, emb_dim, drop_ratio = 0.5, JK = "last", residual = False, gnn_type = 'gin'):
'''
num_layers (int): number of GNN message passing layers
emb_dim (int): node embedding dimensionality
'''
def __init__(
self,
num_layers,
emb_dim,
drop_ratio=0.5,
JK="last",
residual=False,
gnn_type="gin",
):
"""
num_layers (int): number of GNN message passing layers
emb_dim (int): node embedding dimensionality
"""
super(GNN_node_Virtualnode, self).__init__()
self.num_layers = num_layers
......@@ -173,31 +202,38 @@ class GNN_node_Virtualnode(nn.Module):
self.mlp_virtualnode_list = nn.ModuleList()
for layer in range(num_layers):
if gnn_type == 'gin':
if gnn_type == "gin":
self.convs.append(GINConv(emb_dim))
elif gnn_type == 'gcn':
elif gnn_type == "gcn":
self.convs.append(GCNConv(emb_dim))
else:
raise ValueError('Undefined GNN type called {}'.format(gnn_type))
raise ValueError("Undefined GNN type called {}".format(gnn_type))
self.batch_norms.append(nn.BatchNorm1d(emb_dim))
for layer in range(num_layers - 1):
self.mlp_virtualnode_list.append(nn.Sequential(nn.Linear(emb_dim, emb_dim),
nn.BatchNorm1d(emb_dim),
nn.ReLU(),
nn.Linear(emb_dim, emb_dim),
nn.BatchNorm1d(emb_dim),
nn.ReLU()))
self.mlp_virtualnode_list.append(
nn.Sequential(
nn.Linear(emb_dim, emb_dim),
nn.BatchNorm1d(emb_dim),
nn.ReLU(),
nn.Linear(emb_dim, emb_dim),
nn.BatchNorm1d(emb_dim),
nn.ReLU(),
)
)
self.pool = SumPooling()
def forward(self, g, x, edge_attr):
### virtual node embeddings for graphs
virtualnode_embedding = self.virtualnode_embedding(
torch.zeros(g.batch_size).to(x.dtype).to(x.device))
torch.zeros(g.batch_size).to(x.dtype).to(x.device)
)
h_list = [self.atom_encoder(x)]
batch_id = dgl.broadcast_nodes(g, torch.arange(g.batch_size).to(x.device))
batch_id = dgl.broadcast_nodes(
g, torch.arange(g.batch_size).to(x.device)
)
for layer in range(self.num_layers):
### add message from virtual nodes to graph nodes
h_list[layer] = h_list[layer] + virtualnode_embedding[batch_id]
......@@ -206,10 +242,12 @@ class GNN_node_Virtualnode(nn.Module):
h = self.convs[layer](g, h_list[layer], edge_attr)
h = self.batch_norms[layer](h)
if layer == self.num_layers - 1:
#remove relu for the last layer
h = F.dropout(h, self.drop_ratio, training = self.training)
# remove relu for the last layer
h = F.dropout(h, self.drop_ratio, training=self.training)
else:
h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)
h = F.dropout(
F.relu(h), self.drop_ratio, training=self.training
)
if self.residual:
h = h + h_list[layer]
......@@ -219,17 +257,26 @@ class GNN_node_Virtualnode(nn.Module):
### update the virtual nodes
if layer < self.num_layers - 1:
### add message from graph nodes to virtual nodes
virtualnode_embedding_temp = self.pool(g, h_list[layer]) + virtualnode_embedding
virtualnode_embedding_temp = (
self.pool(g, h_list[layer]) + virtualnode_embedding
)
### transform virtual nodes using MLP
virtualnode_embedding_temp = self.mlp_virtualnode_list[layer](
virtualnode_embedding_temp)
virtualnode_embedding_temp
)
if self.residual:
virtualnode_embedding = virtualnode_embedding + F.dropout(
virtualnode_embedding_temp, self.drop_ratio, training = self.training)
virtualnode_embedding_temp,
self.drop_ratio,
training=self.training,
)
else:
virtualnode_embedding = F.dropout(
virtualnode_embedding_temp, self.drop_ratio, training = self.training)
virtualnode_embedding_temp,
self.drop_ratio,
training=self.training,
)
### Different implementations of Jk-concat
if self.JK == "last":
......
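# A hedged, self-contained toy of the virtual-node exchange used in
# GNN_node_Virtualnode above (all names below are made up for illustration):
import dgl
import torch
from dgl.nn.pytorch import SumPooling

toy_g = dgl.batch([dgl.rand_graph(3, 4), dgl.rand_graph(2, 2)])
toy_h = torch.ones(toy_g.num_nodes(), 8)  # node embeddings
toy_vn = torch.zeros(toy_g.batch_size, 8)  # one virtual node per graph
toy_batch_id = dgl.broadcast_nodes(toy_g, torch.arange(toy_g.batch_size))
toy_h = toy_h + toy_vn[toy_batch_id]  # virtual node -> graph nodes
toy_vn = SumPooling()(toy_g, toy_h) + toy_vn  # graph nodes -> virtual node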
import torch
import torch.nn as nn
from conv import GNN_node, GNN_node_Virtualnode
from dgl.nn.pytorch import SumPooling, AvgPooling, MaxPooling, GlobalAttentionPooling, Set2Set
from dgl.nn.pytorch import (
AvgPooling,
GlobalAttentionPooling,
MaxPooling,
Set2Set,
SumPooling,
)
from conv import GNN_node, GNN_node_Virtualnode
class GNN(nn.Module):
def __init__(self, num_tasks = 1, num_layers = 5, emb_dim = 300, gnn_type = 'gin',
virtual_node = True, residual = False, drop_ratio = 0, JK = "last",
graph_pooling = "sum"):
'''
num_tasks (int): number of labels to be predicted
virtual_node (bool): whether to add virtual node or not
'''
def __init__(
self,
num_tasks=1,
num_layers=5,
emb_dim=300,
gnn_type="gin",
virtual_node=True,
residual=False,
drop_ratio=0,
JK="last",
graph_pooling="sum",
):
"""
num_tasks (int): number of labels to be predicted
virtual_node (bool): whether to add virtual node or not
"""
super(GNN, self).__init__()
self.num_layers = num_layers
......@@ -28,14 +42,23 @@ class GNN(nn.Module):
### GNN to generate node embeddings
if virtual_node:
self.gnn_node = GNN_node_Virtualnode(num_layers, emb_dim, JK = JK,
drop_ratio = drop_ratio,
residual = residual,
gnn_type = gnn_type)
self.gnn_node = GNN_node_Virtualnode(
num_layers,
emb_dim,
JK=JK,
drop_ratio=drop_ratio,
residual=residual,
gnn_type=gnn_type,
)
else:
self.gnn_node = GNN_node(num_layers, emb_dim, JK = JK, drop_ratio = drop_ratio,
residual = residual, gnn_type = gnn_type)
self.gnn_node = GNN_node(
num_layers,
emb_dim,
JK=JK,
drop_ratio=drop_ratio,
residual=residual,
gnn_type=gnn_type,
)
### Pooling function to generate whole-graph embeddings
if self.graph_pooling == "sum":
......@@ -46,18 +69,21 @@ class GNN(nn.Module):
self.pool = MaxPooling()
elif self.graph_pooling == "attention":
self.pool = GlobalAttentionPooling(
gate_nn = nn.Sequential(nn.Linear(emb_dim, 2*emb_dim),
nn.BatchNorm1d(2*emb_dim),
nn.ReLU(),
nn.Linear(2*emb_dim, 1)))
gate_nn=nn.Sequential(
nn.Linear(emb_dim, 2 * emb_dim),
nn.BatchNorm1d(2 * emb_dim),
nn.ReLU(),
nn.Linear(2 * emb_dim, 1),
)
)
elif self.graph_pooling == "set2set":
self.pool = Set2Set(emb_dim, n_iters = 2, n_layers = 2)
self.pool = Set2Set(emb_dim, n_iters=2, n_layers=2)
else:
raise ValueError("Invalid graph pooling type.")
if graph_pooling == "set2set":
self.graph_pred_linear = nn.Linear(2*self.emb_dim, self.num_tasks)
self.graph_pred_linear = nn.Linear(2 * self.emb_dim, self.num_tasks)
else:
self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks)
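# Set2Set's read-out is the concatenation of its query and attention vectors,
# so its output dimension is 2 * emb_dim, which is why the prediction head
# above doubles its input size only in the set2set branch.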
......
import argparse
import dgl
import numpy as np
import os
import random
import numpy as np
import torch
import torch.optim as optim
from gnn import GNN
from ogb.lsc import DglPCQM4MDataset, PCQM4MEvaluator
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
from gnn import GNN
import dgl
reg_criterion = torch.nn.L1Loss()
......@@ -30,11 +31,13 @@ def train(model, device, loader, optimizer):
for step, (bg, labels) in enumerate(tqdm(loader, desc="Iteration")):
bg = bg.to(device)
x = bg.ndata.pop('feat')
edge_attr = bg.edata.pop('feat')
x = bg.ndata.pop("feat")
edge_attr = bg.edata.pop("feat")
labels = labels.to(device)
pred = model(bg, x, edge_attr).view(-1,)
pred = model(bg, x, edge_attr).view(
-1,
)
optimizer.zero_grad()
loss = reg_criterion(pred, labels)
loss.backward()
......@@ -52,12 +55,14 @@ def eval(model, device, loader, evaluator):
for step, (bg, labels) in enumerate(tqdm(loader, desc="Iteration")):
bg = bg.to(device)
x = bg.ndata.pop('feat')
edge_attr = bg.edata.pop('feat')
x = bg.ndata.pop("feat")
edge_attr = bg.edata.pop("feat")
labels = labels.to(device)
with torch.no_grad():
pred = model(bg, x, edge_attr).view(-1, )
pred = model(bg, x, edge_attr).view(
-1,
)
y_true.append(labels.view(pred.shape).detach().cpu())
y_pred.append(pred.detach().cpu())
......@@ -76,11 +81,13 @@ def test(model, device, loader):
for step, (bg, _) in enumerate(tqdm(loader, desc="Iteration")):
bg = bg.to(device)
x = bg.ndata.pop('feat')
edge_attr = bg.edata.pop('feat')
x = bg.ndata.pop("feat")
edge_attr = bg.edata.pop("feat")
with torch.no_grad():
pred = model(bg, x, edge_attr).view(-1, )
pred = model(bg, x, edge_attr).view(
-1,
)
y_pred.append(pred.detach().cpu())
......@@ -91,37 +98,88 @@ def test(model, device, loader):
def main():
# Training settings
parser = argparse.ArgumentParser(description='GNN baselines on pcqm4m with DGL')
parser.add_argument('--seed', type=int, default=42,
help='random seed to use (default: 42)')
parser.add_argument('--device', type=int, default=0,
help='which gpu to use if any (default: 0)')
parser.add_argument('--gnn', type=str, default='gin-virtual',
help='GNN to use, which can be from '
'[gin, gin-virtual, gcn, gcn-virtual] (default: gin-virtual)')
parser.add_argument('--graph_pooling', type=str, default='sum',
help='graph pooling strategy mean or sum (default: sum)')
parser.add_argument('--drop_ratio', type=float, default=0,
help='dropout ratio (default: 0)')
parser.add_argument('--num_layers', type=int, default=5,
help='number of GNN message passing layers (default: 5)')
parser.add_argument('--emb_dim', type=int, default=600,
help='dimensionality of hidden units in GNNs (default: 600)')
parser.add_argument('--train_subset', action='store_true',
help='use 10%% of the training set for training')
parser.add_argument('--batch_size', type=int, default=256,
help='input batch size for training (default: 256)')
parser.add_argument('--epochs', type=int, default=100,
help='number of epochs to train (default: 100)')
parser.add_argument('--num_workers', type=int, default=0,
help='number of workers (default: 0)')
parser.add_argument('--log_dir', type=str, default="",
help='tensorboard log directory. If not specified, '
'tensorboard will not be used.')
parser.add_argument('--checkpoint_dir', type=str, default='',
help='directory to save checkpoint')
parser.add_argument('--save_test_dir', type=str, default='',
help='directory to save test submission file')
parser = argparse.ArgumentParser(
description="GNN baselines on pcqm4m with DGL"
)
parser.add_argument(
"--seed", type=int, default=42, help="random seed to use (default: 42)"
)
parser.add_argument(
"--device",
type=int,
default=0,
help="which gpu to use if any (default: 0)",
)
parser.add_argument(
"--gnn",
type=str,
default="gin-virtual",
help="GNN to use, which can be from "
"[gin, gin-virtual, gcn, gcn-virtual] (default: gin-virtual)",
)
parser.add_argument(
"--graph_pooling",
type=str,
default="sum",
help="graph pooling strategy mean or sum (default: sum)",
)
parser.add_argument(
"--drop_ratio", type=float, default=0, help="dropout ratio (default: 0)"
)
parser.add_argument(
"--num_layers",
type=int,
default=5,
help="number of GNN message passing layers (default: 5)",
)
parser.add_argument(
"--emb_dim",
type=int,
default=600,
help="dimensionality of hidden units in GNNs (default: 600)",
)
parser.add_argument(
"--train_subset",
action="store_true",
help="use 10% of the training set for training",
)
parser.add_argument(
"--batch_size",
type=int,
default=256,
help="input batch size for training (default: 256)",
)
parser.add_argument(
"--epochs",
type=int,
default=100,
help="number of epochs to train (default: 100)",
)
parser.add_argument(
"--num_workers",
type=int,
default=0,
help="number of workers (default: 0)",
)
parser.add_argument(
"--log_dir",
type=str,
default="",
help="tensorboard log directory. If not specified, "
"tensorboard will not be used.",
)
parser.add_argument(
"--checkpoint_dir",
type=str,
default="",
help="directory to save checkpoint",
)
parser.add_argument(
"--save_test_dir",
type=str,
default="",
help="directory to save test submission file",
)
args = parser.parse_args()
print(args)
......@@ -137,7 +195,7 @@ def main():
device = torch.device("cpu")
### automatic dataloading and splitting
dataset = DglPCQM4MDataset(root='dataset/')
dataset = DglPCQM4MDataset(root="dataset/")
# split_idx['train'], split_idx['valid'], split_idx['test']
# separately gives a 1D int64 tensor
......@@ -148,47 +206,77 @@ def main():
if args.train_subset:
subset_ratio = 0.1
subset_idx = torch.randperm(len(split_idx["train"]))[:int(subset_ratio * len(split_idx["train"]))]
train_loader = DataLoader(dataset[split_idx["train"][subset_idx]], batch_size=args.batch_size, shuffle=True,
num_workers=args.num_workers, collate_fn=collate_dgl)
subset_idx = torch.randperm(len(split_idx["train"]))[
: int(subset_ratio * len(split_idx["train"]))
]
train_loader = DataLoader(
dataset[split_idx["train"][subset_idx]],
batch_size=args.batch_size,
shuffle=True,
num_workers=args.num_workers,
collate_fn=collate_dgl,
)
else:
train_loader = DataLoader(dataset[split_idx["train"]], batch_size=args.batch_size, shuffle=True,
num_workers=args.num_workers, collate_fn=collate_dgl)
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=args.batch_size, shuffle=False,
num_workers=args.num_workers, collate_fn=collate_dgl)
if args.save_test_dir != '':
test_loader = DataLoader(dataset[split_idx["test"]], batch_size=args.batch_size, shuffle=False,
num_workers=args.num_workers, collate_fn=collate_dgl)
if args.checkpoint_dir != '':
train_loader = DataLoader(
dataset[split_idx["train"]],
batch_size=args.batch_size,
shuffle=True,
num_workers=args.num_workers,
collate_fn=collate_dgl,
)
valid_loader = DataLoader(
dataset[split_idx["valid"]],
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
collate_fn=collate_dgl,
)
if args.save_test_dir != "":
test_loader = DataLoader(
dataset[split_idx["test"]],
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
collate_fn=collate_dgl,
)
if args.checkpoint_dir != "":
os.makedirs(args.checkpoint_dir, exist_ok=True)
shared_params = {
'num_layers': args.num_layers,
'emb_dim': args.emb_dim,
'drop_ratio': args.drop_ratio,
'graph_pooling': args.graph_pooling
"num_layers": args.num_layers,
"emb_dim": args.emb_dim,
"drop_ratio": args.drop_ratio,
"graph_pooling": args.graph_pooling,
}
if args.gnn == 'gin':
model = GNN(gnn_type='gin', virtual_node=False, **shared_params).to(device)
elif args.gnn == 'gin-virtual':
model = GNN(gnn_type='gin', virtual_node=True, **shared_params).to(device)
elif args.gnn == 'gcn':
model = GNN(gnn_type='gcn', virtual_node=False, **shared_params).to(device)
elif args.gnn == 'gcn-virtual':
model = GNN(gnn_type='gcn', virtual_node=True, **shared_params).to(device)
if args.gnn == "gin":
model = GNN(gnn_type="gin", virtual_node=False, **shared_params).to(
device
)
elif args.gnn == "gin-virtual":
model = GNN(gnn_type="gin", virtual_node=True, **shared_params).to(
device
)
elif args.gnn == "gcn":
model = GNN(gnn_type="gcn", virtual_node=False, **shared_params).to(
device
)
elif args.gnn == "gcn-virtual":
model = GNN(gnn_type="gcn", virtual_node=True, **shared_params).to(
device
)
else:
raise ValueError('Invalid GNN type')
raise ValueError("Invalid GNN type")
num_params = sum(p.numel() for p in model.parameters())
print(f'#Params: {num_params}')
print(f"#Params: {num_params}")
optimizer = optim.Adam(model.parameters(), lr=0.001)
if args.log_dir != '':
if args.log_dir != "":
writer = SummaryWriter(log_dir=args.log_dir)
best_valid_mae = 1000
......@@ -201,40 +289,50 @@ def main():
for epoch in range(1, args.epochs + 1):
print("=====Epoch {}".format(epoch))
print('Training...')
print("Training...")
train_mae = train(model, device, train_loader, optimizer)
print('Evaluating...')
print("Evaluating...")
valid_mae = eval(model, device, valid_loader, evaluator)
print({'Train': train_mae, 'Validation': valid_mae})
print({"Train": train_mae, "Validation": valid_mae})
if args.log_dir != '':
writer.add_scalar('valid/mae', valid_mae, epoch)
writer.add_scalar('train/mae', train_mae, epoch)
if args.log_dir != "":
writer.add_scalar("valid/mae", valid_mae, epoch)
writer.add_scalar("train/mae", train_mae, epoch)
if valid_mae < best_valid_mae:
best_valid_mae = valid_mae
if args.checkpoint_dir != '':
print('Saving checkpoint...')
checkpoint = {'epoch': epoch, 'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(), 'best_val_mae': best_valid_mae,
'num_params': num_params}
torch.save(checkpoint, os.path.join(args.checkpoint_dir, 'checkpoint.pt'))
if args.save_test_dir != '':
print('Predicting on test data...')
if args.checkpoint_dir != "":
print("Saving checkpoint...")
checkpoint = {
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
"best_val_mae": best_valid_mae,
"num_params": num_params,
}
torch.save(
checkpoint,
os.path.join(args.checkpoint_dir, "checkpoint.pt"),
)
if args.save_test_dir != "":
print("Predicting on test data...")
y_pred = test(model, device, test_loader)
print('Saving test submission file...')
evaluator.save_test_submission({'y_pred': y_pred}, args.save_test_dir)
print("Saving test submission file...")
evaluator.save_test_submission(
{"y_pred": y_pred}, args.save_test_dir
)
scheduler.step()
print(f'Best validation MAE so far: {best_valid_mae}')
print(f"Best validation MAE so far: {best_valid_mae}")
if args.log_dir != '':
if args.log_dir != "":
writer.close()
if __name__ == "__main__":
main()
import argparse
import dgl
import numpy as np
import os
import random
import torch
import numpy as np
import torch
from gnn import GNN
from ogb.lsc import PCQM4MDataset, PCQM4MEvaluator
from ogb.utils import smiles2graph
from torch.utils.data import DataLoader
from tqdm import tqdm
from gnn import GNN
import dgl
def collate_dgl(graphs):
batched_graph = dgl.batch(graphs)
return batched_graph
def test(model, device, loader):
model.eval()
y_pred = []
for step, bg in enumerate(tqdm(loader, desc="Iteration")):
bg = bg.to(device)
x = bg.ndata.pop('feat')
edge_attr = bg.edata.pop('feat')
x = bg.ndata.pop("feat")
edge_attr = bg.edata.pop("feat")
with torch.no_grad():
pred = model(bg, x, edge_attr).view(-1, )
pred = model(bg, x, edge_attr).view(
-1,
)
y_pred.append(pred.detach().cpu())
......@@ -43,53 +47,99 @@ class OnTheFlyPCQMDataset(object):
self.smiles2graph = smiles2graph
def __getitem__(self, idx):
'''Get datapoint with index'''
"""Get datapoint with index"""
smiles, _ = self.smiles_list[idx]
graph = self.smiles2graph(smiles)
dgl_graph = dgl.graph((graph['edge_index'][0], graph['edge_index'][1]),
num_nodes=graph['num_nodes'])
dgl_graph.edata['feat'] = torch.from_numpy(graph['edge_feat']).to(torch.int64)
dgl_graph.ndata['feat'] = torch.from_numpy(graph['node_feat']).to(torch.int64)
dgl_graph = dgl.graph(
(graph["edge_index"][0], graph["edge_index"][1]),
num_nodes=graph["num_nodes"],
)
dgl_graph.edata["feat"] = torch.from_numpy(graph["edge_feat"]).to(
torch.int64
)
dgl_graph.ndata["feat"] = torch.from_numpy(graph["node_feat"]).to(
torch.int64
)
return dgl_graph
def __len__(self):
'''Length of the dataset
"""Length of the dataset
Returns
-------
int
Length of Dataset
'''
"""
return len(self.smiles_list)
def main():
# Training settings
parser = argparse.ArgumentParser(description='GNN baselines on pcqm4m with DGL')
parser.add_argument('--seed', type=int, default=42,
help='random seed to use (default: 42)')
parser.add_argument('--device', type=int, default=0,
help='which gpu to use if any (default: 0)')
parser.add_argument('--gnn', type=str, default='gin-virtual',
help='GNN to use, which can be from '
'[gin, gin-virtual, gcn, gcn-virtual] (default: gin-virtual)')
parser.add_argument('--graph_pooling', type=str, default='sum',
help='graph pooling strategy mean or sum (default: sum)')
parser.add_argument('--drop_ratio', type=float, default=0,
help='dropout ratio (default: 0)')
parser.add_argument('--num_layers', type=int, default=5,
help='number of GNN message passing layers (default: 5)')
parser.add_argument('--emb_dim', type=int, default=600,
help='dimensionality of hidden units in GNNs (default: 600)')
parser.add_argument('--batch_size', type=int, default=256,
help='input batch size for training (default: 256)')
parser.add_argument('--num_workers', type=int, default=0,
help='number of workers (default: 0)')
parser.add_argument('--checkpoint_dir', type=str, default='',
help='directory to save checkpoint')
parser.add_argument('--save_test_dir', type=str, default='',
help='directory to save test submission file')
parser = argparse.ArgumentParser(
description="GNN baselines on pcqm4m with DGL"
)
parser.add_argument(
"--seed", type=int, default=42, help="random seed to use (default: 42)"
)
parser.add_argument(
"--device",
type=int,
default=0,
help="which gpu to use if any (default: 0)",
)
parser.add_argument(
"--gnn",
type=str,
default="gin-virtual",
help="GNN to use, which can be from "
"[gin, gin-virtual, gcn, gcn-virtual] (default: gin-virtual)",
)
parser.add_argument(
"--graph_pooling",
type=str,
default="sum",
help="graph pooling strategy mean or sum (default: sum)",
)
parser.add_argument(
"--drop_ratio", type=float, default=0, help="dropout ratio (default: 0)"
)
parser.add_argument(
"--num_layers",
type=int,
default=5,
help="number of GNN message passing layers (default: 5)",
)
parser.add_argument(
"--emb_dim",
type=int,
default=600,
help="dimensionality of hidden units in GNNs (default: 600)",
)
parser.add_argument(
"--batch_size",
type=int,
default=256,
help="input batch size for training (default: 256)",
)
parser.add_argument(
"--num_workers",
type=int,
default=0,
help="number of workers (default: 0)",
)
parser.add_argument(
"--checkpoint_dir",
type=str,
default="",
help="directory to save checkpoint",
)
parser.add_argument(
"--save_test_dir",
type=str,
default="",
help="directory to save test submission file",
)
args = parser.parse_args()
print(args)
......@@ -106,50 +156,63 @@ def main():
### automatic data loading and splitting
### Read in the raw SMILES strings
smiles_dataset = PCQM4MDataset(root='dataset/', only_smiles=True)
smiles_dataset = PCQM4MDataset(root="dataset/", only_smiles=True)
split_idx = smiles_dataset.get_idx_split()
test_smiles_dataset = [smiles_dataset[i] for i in split_idx['test']]
test_smiles_dataset = [smiles_dataset[i] for i in split_idx["test"]]
onthefly_dataset = OnTheFlyPCQMDataset(test_smiles_dataset)
test_loader = DataLoader(onthefly_dataset, batch_size=args.batch_size, shuffle=False,
num_workers=args.num_workers, collate_fn=collate_dgl)
test_loader = DataLoader(
onthefly_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
collate_fn=collate_dgl,
)
### automatic evaluator.
evaluator = PCQM4MEvaluator()
shared_params = {
'num_layers': args.num_layers,
'emb_dim': args.emb_dim,
'drop_ratio': args.drop_ratio,
'graph_pooling': args.graph_pooling
"num_layers": args.num_layers,
"emb_dim": args.emb_dim,
"drop_ratio": args.drop_ratio,
"graph_pooling": args.graph_pooling,
}
if args.gnn == 'gin':
model = GNN(gnn_type='gin', virtual_node=False, **shared_params).to(device)
elif args.gnn == 'gin-virtual':
model = GNN(gnn_type='gin', virtual_node=True, **shared_params).to(device)
elif args.gnn == 'gcn':
model = GNN(gnn_type='gcn', virtual_node=False, **shared_params).to(device)
elif args.gnn == 'gcn-virtual':
model = GNN(gnn_type='gcn', virtual_node=True, **shared_params).to(device)
if args.gnn == "gin":
model = GNN(gnn_type="gin", virtual_node=False, **shared_params).to(
device
)
elif args.gnn == "gin-virtual":
model = GNN(gnn_type="gin", virtual_node=True, **shared_params).to(
device
)
elif args.gnn == "gcn":
model = GNN(gnn_type="gcn", virtual_node=False, **shared_params).to(
device
)
elif args.gnn == "gcn-virtual":
model = GNN(gnn_type="gcn", virtual_node=True, **shared_params).to(
device
)
else:
raise ValueError('Invalid GNN type')
raise ValueError("Invalid GNN type")
num_params = sum(p.numel() for p in model.parameters())
print(f'#Params: {num_params}')
print(f"#Params: {num_params}")
checkpoint_path = os.path.join(args.checkpoint_dir, 'checkpoint.pt')
checkpoint_path = os.path.join(args.checkpoint_dir, "checkpoint.pt")
if not os.path.exists(checkpoint_path):
raise RuntimeError(f'Checkpoint file not found at {checkpoint_path}')
raise RuntimeError(f"Checkpoint file not found at {checkpoint_path}")
## reading in checkpoint
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model_state_dict'])
model.load_state_dict(checkpoint["model_state_dict"])
print('Predicting on test data...')
print("Predicting on test data...")
y_pred = test(model, device, test_loader)
print('Saving test submission file...')
evaluator.save_test_submission({'y_pred': y_pred}, args.save_test_dir)
print("Saving test submission file...")
evaluator.save_test_submission({"y_pred": y_pred}, args.save_test_dir)
if __name__ == "__main__":
......
"""Graph builder from pandas dataframes"""
from collections import namedtuple
from pandas.api.types import is_numeric_dtype, is_categorical_dtype, is_categorical
from pandas.api.types import (
is_categorical,
is_categorical_dtype,
is_numeric_dtype,
)
import torch

import dgl
__all__ = ['PandasGraphBuilder']
__all__ = ["PandasGraphBuilder"]
def _series_to_tensor(series):
if is_categorical(series):
return torch.LongTensor(series.cat.codes.values.astype('int64'))
else: # numeric
return torch.LongTensor(series.cat.codes.values.astype("int64"))
else: # numeric
return torch.FloatTensor(series.values)
class PandasGraphBuilder(object):
"""Creates a heterogeneous graph from multiple pandas dataframes.
......@@ -60,25 +68,36 @@ class PandasGraphBuilder(object):
>>> g.num_edges('plays')
4
"""
def __init__(self):
self.entity_tables = {}
self.relation_tables = {}
self.entity_pk_to_name = {} # mapping from primary key name to entity name
self.entity_pk = {} # mapping from entity name to primary key
self.entity_key_map = {} # mapping from entity names to primary key values
self.entity_pk_to_name = (
{}
) # mapping from primary key name to entity name
self.entity_pk = {} # mapping from entity name to primary key
self.entity_key_map = (
{}
) # mapping from entity names to primary key values
self.num_nodes_per_type = {}
self.edges_per_relation = {}
self.relation_name_to_etype = {}
self.relation_src_key = {} # mapping from relation name to source key
self.relation_dst_key = {} # mapping from relation name to destination key
self.relation_src_key = {} # mapping from relation name to source key
self.relation_dst_key = (
{}
) # mapping from relation name to destination key
def add_entities(self, entity_table, primary_key, name):
entities = entity_table[primary_key].astype('category')
entities = entity_table[primary_key].astype("category")
if not (entities.value_counts() == 1).all():
raise ValueError('Different entity with the same primary key detected.')
raise ValueError(
"Different entity with the same primary key detected."
)
# preserve the category order in the original entity table
entities = entities.cat.reorder_categories(entity_table[primary_key].values)
entities = entities.cat.reorder_categories(
entity_table[primary_key].values
)
self.entity_pk_to_name[primary_key] = name
self.entity_pk[name] = primary_key
......@@ -86,33 +105,47 @@ class PandasGraphBuilder(object):
self.entity_key_map[name] = entities
self.entity_tables[name] = entity_table
def add_binary_relations(self, relation_table, source_key, destination_key, name):
src = relation_table[source_key].astype('category')
def add_binary_relations(
self, relation_table, source_key, destination_key, name
):
src = relation_table[source_key].astype("category")
src = src.cat.set_categories(
self.entity_key_map[self.entity_pk_to_name[source_key]].cat.categories)
dst = relation_table[destination_key].astype('category')
self.entity_key_map[
self.entity_pk_to_name[source_key]
].cat.categories
)
dst = relation_table[destination_key].astype("category")
dst = dst.cat.set_categories(
self.entity_key_map[self.entity_pk_to_name[destination_key]].cat.categories)
self.entity_key_map[
self.entity_pk_to_name[destination_key]
].cat.categories
)
if src.isnull().any():
raise ValueError(
'Some source entities in relation %s do not exist in entity %s.' %
(name, source_key))
"Some source entities in relation %s do not exist in entity %s."
% (name, source_key)
)
if dst.isnull().any():
raise ValueError(
'Some destination entities in relation %s do not exist in entity %s.' %
(name, destination_key))
"Some destination entities in relation %s do not exist in entity %s."
% (name, destination_key)
)
srctype = self.entity_pk_to_name[source_key]
dsttype = self.entity_pk_to_name[destination_key]
etype = (srctype, name, dsttype)
self.relation_name_to_etype[name] = etype
self.edges_per_relation[etype] = (src.cat.codes.values.astype('int64'), dst.cat.codes.values.astype('int64'))
self.edges_per_relation[etype] = (
src.cat.codes.values.astype("int64"),
dst.cat.codes.values.astype("int64"),
)
self.relation_tables[name] = relation_table
self.relation_src_key[name] = source_key
self.relation_dst_key[name] = destination_key
def build(self):
# Create heterograph
graph = dgl.heterograph(self.edges_per_relation, self.num_nodes_per_type)
graph = dgl.heterograph(
self.edges_per_relation, self.num_nodes_per_type
)
return graph
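A short sketch of the builder API, mirroring the users/games/plays example from the class docstring; the toy dataframes and column values below are hypothetical. Entities are registered with their primary keys, relations with their foreign keys, and build() assembles the heterograph.
import pandas as pd

users = pd.DataFrame({"user_id": ["A", "B", "C"]})
games = pd.DataFrame({"game_id": [1, 2]})
plays = pd.DataFrame({"user_id": ["A", "A", "B", "C"], "game_id": [1, 2, 2, 2]})

builder = PandasGraphBuilder()
builder.add_entities(users, "user_id", "user")
builder.add_entities(games, "game_id", "game")
builder.add_binary_relations(plays, "user_id", "game_id", "plays")
builder.add_binary_relations(plays, "game_id", "user_id", "played-by")
g = builder.build()
print(g.num_nodes("user"), g.num_edges("plays"))  # 3 4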
import torch
import dgl
import dask.dataframe as dd
import numpy as np
import scipy.sparse as ssp
import torch
import tqdm
import dask.dataframe as dd
import dgl
# This is the train-test split method most recommender system papers on MovieLens
# use. It follows the intuition of "training on the past and predicting the future".
# One can also change the threshold so that the validation and test sets take larger proportions.
def train_test_split_by_time(df, timestamp, user):
df['train_mask'] = np.ones((len(df),), dtype=np.bool)
df['val_mask'] = np.zeros((len(df),), dtype=np.bool)
df['test_mask'] = np.zeros((len(df),), dtype=np.bool)
df["train_mask"] = np.ones((len(df),), dtype=np.bool)
df["val_mask"] = np.zeros((len(df),), dtype=np.bool)
df["test_mask"] = np.zeros((len(df),), dtype=np.bool)
df = dd.from_pandas(df, npartitions=10)
def train_test_split(df):
df = df.sort_values([timestamp])
if df.shape[0] > 1:
......@@ -22,16 +25,25 @@ def train_test_split_by_time(df, timestamp, user):
df.iloc[-2, -3] = False
df.iloc[-2, -2] = True
return df
df = df.groupby(user, group_keys=False).apply(train_test_split).compute(scheduler='processes').sort_index()
df = (
df.groupby(user, group_keys=False)
.apply(train_test_split)
.compute(scheduler="processes")
.sort_index()
)
print(df[df[user] == df[user].unique()[0]].sort_values(timestamp))
return df['train_mask'].to_numpy().nonzero()[0], \
df['val_mask'].to_numpy().nonzero()[0], \
df['test_mask'].to_numpy().nonzero()[0]
return (
df["train_mask"].to_numpy().nonzero()[0],
df["val_mask"].to_numpy().nonzero()[0],
df["test_mask"].to_numpy().nonzero()[0],
)
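A toy run of the time-based split, assuming a small pandas interaction log with hypothetical column names: per user, the latest interaction goes to the test set, the second latest to validation, and everything earlier stays in training.
import pandas as pd

ratings = pd.DataFrame(
    {
        "user_id": [0, 0, 0, 1, 1, 1],
        "movie_id": [10, 11, 12, 10, 12, 13],
        "timestamp": [1, 2, 3, 1, 2, 3],
    }
)
train_idx, val_idx, test_idx = train_test_split_by_time(ratings, "timestamp", "user_id")
# With three interactions per user, each user contributes one validation and
# one test example, so the expected sizes are 2, 2 and 2.
print(len(train_idx), len(val_idx), len(test_idx))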
def build_train_graph(g, train_indices, utype, itype, etype, etype_rev):
train_g = g.edge_subgraph(
{etype: train_indices, etype_rev: train_indices},
relabel_nodes=False)
{etype: train_indices, etype_rev: train_indices}, relabel_nodes=False
)
# copy features
for ntype in g.ntypes:
......@@ -39,10 +51,13 @@ def build_train_graph(g, train_indices, utype, itype, etype, etype_rev):
train_g.nodes[ntype].data[col] = data
for etype in g.etypes:
for col, data in g.edges[etype].data.items():
train_g.edges[etype].data[col] = data[train_g.edges[etype].data[dgl.EID]]
train_g.edges[etype].data[col] = data[
train_g.edges[etype].data[dgl.EID]
]
return train_g
def build_val_test_matrix(g, val_indices, test_indices, utype, itype, etype):
n_users = g.num_nodes(utype)
n_items = g.num_nodes(itype)
......@@ -52,11 +67,17 @@ def build_val_test_matrix(g, val_indices, test_indices, utype, itype, etype):
val_dst = val_dst.numpy()
test_src = test_src.numpy()
test_dst = test_dst.numpy()
val_matrix = ssp.coo_matrix((np.ones_like(val_src), (val_src, val_dst)), (n_users, n_items))
test_matrix = ssp.coo_matrix((np.ones_like(test_src), (test_src, test_dst)), (n_users, n_items))
val_matrix = ssp.coo_matrix(
(np.ones_like(val_src), (val_src, val_dst)), (n_users, n_items)
)
test_matrix = ssp.coo_matrix(
(np.ones_like(test_src), (test_src, test_dst)), (n_users, n_items)
)
return val_matrix, test_matrix
def linear_normalize(values):
return (values - values.min(0, keepdims=True)) / \
(values.max(0, keepdims=True) - values.min(0, keepdims=True))
return (values - values.min(0, keepdims=True)) / (
values.max(0, keepdims=True) - values.min(0, keepdims=True)
)
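A tiny worked example of the column-wise min-max scaling, assuming a NumPy array input.
demo = np.array([[1.0, 10.0], [3.0, 20.0], [5.0, 30.0]])
# Each column is rescaled independently to [0, 1]: both columns become 0, 0.5, 1.
print(linear_normalize(demo))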
import argparse
import pickle
import numpy as np
import torch
import pickle
import dgl
import argparse
def prec(recommendations, ground_truth):
n_users, n_items = ground_truth.shape
......@@ -13,8 +16,11 @@ def prec(recommendations, ground_truth):
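    # A user counts as a hit if any of their top-K recommendations appears in the
    # ground truth; the returned value is the fraction of users with at least one hit.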
hit = relevance.any(axis=1).mean()
return hit
class LatestNNRecommender(object):
def __init__(self, user_ntype, item_ntype, user_to_item_etype, timestamp, batch_size):
def __init__(
self, user_ntype, item_ntype, user_to_item_etype, timestamp, batch_size
):
self.user_ntype = user_ntype
self.item_ntype = item_ntype
self.user_to_item_etype = user_to_item_etype
......@@ -27,19 +33,27 @@ class LatestNNRecommender(object):
"""
graph_slice = full_graph.edge_type_subgraph([self.user_to_item_etype])
n_users = full_graph.num_nodes(self.user_ntype)
latest_interactions = dgl.sampling.select_topk(graph_slice, 1, self.timestamp, edge_dir='out')
user, latest_items = latest_interactions.all_edges(form='uv', order='srcdst')
latest_interactions = dgl.sampling.select_topk(
graph_slice, 1, self.timestamp, edge_dir="out"
)
user, latest_items = latest_interactions.all_edges(
form="uv", order="srcdst"
)
# each user should have at least one "latest" interaction
assert torch.equal(user, torch.arange(n_users))
recommended_batches = []
user_batches = torch.arange(n_users).split(self.batch_size)
for user_batch in user_batches:
latest_item_batch = latest_items[user_batch].to(device=h_item.device)
latest_item_batch = latest_items[user_batch].to(
device=h_item.device
)
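            # Score every candidate item by the dot product between the embedding of
            # each user's most recent item and all item embeddings.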
dist = h_item[latest_item_batch] @ h_item.t()
# Exclude items the user has already interacted with.
for i, u in enumerate(user_batch.tolist()):
interacted_items = full_graph.successors(u, etype=self.user_to_item_etype)
interacted_items = full_graph.successors(
u, etype=self.user_to_item_etype
)
dist[i, interacted_items] = -np.inf
recommended_batches.append(dist.topk(K, 1)[1])
......@@ -48,31 +62,33 @@ class LatestNNRecommender(object):
def evaluate_nn(dataset, h_item, k, batch_size):
g = dataset['train-graph']
val_matrix = dataset['val-matrix'].tocsr()
test_matrix = dataset['test-matrix'].tocsr()
item_texts = dataset['item-texts']
user_ntype = dataset['user-type']
item_ntype = dataset['item-type']
user_to_item_etype = dataset['user-to-item-type']
timestamp = dataset['timestamp-edge-column']
g = dataset["train-graph"]
val_matrix = dataset["val-matrix"].tocsr()
test_matrix = dataset["test-matrix"].tocsr()
item_texts = dataset["item-texts"]
user_ntype = dataset["user-type"]
item_ntype = dataset["item-type"]
user_to_item_etype = dataset["user-to-item-type"]
timestamp = dataset["timestamp-edge-column"]
rec_engine = LatestNNRecommender(
user_ntype, item_ntype, user_to_item_etype, timestamp, batch_size)
user_ntype, item_ntype, user_to_item_etype, timestamp, batch_size
)
recommendations = rec_engine.recommend(g, k, None, h_item).cpu().numpy()
return prec(recommendations, val_matrix)
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('dataset_path', type=str)
parser.add_argument('item_embedding_path', type=str)
parser.add_argument('-k', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument("dataset_path", type=str)
parser.add_argument("item_embedding_path", type=str)
parser.add_argument("-k", type=int, default=10)
parser.add_argument("--batch-size", type=int, default=32)
args = parser.parse_args()
with open(args.dataset_path, 'rb') as f:
with open(args.dataset_path, "rb") as f:
dataset = pickle.load(f)
with open(args.item_embedding_path, 'rb') as f:
with open(args.item_embedding_path, "rb") as f:
emb = torch.FloatTensor(pickle.load(f))
print(evaluate_nn(dataset, emb, args.k, args.batch_size))
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.nn.pytorch as dglnn
import dgl.function as fn
import dgl.nn.pytorch as dglnn
def disable_grad(module):
for param in module.parameters():
param.requires_grad = False
def _init_input_modules(g, ntype, textset, hidden_dims):
# We initialize the linear projections of each input feature ``x`` as
# follows:
......@@ -30,44 +33,50 @@ def _init_input_modules(g, ntype, textset, hidden_dims):
module_dict[column] = m
elif data.dtype == torch.int64:
assert data.ndim == 1
m = nn.Embedding(
data.max() + 2, hidden_dims, padding_idx=-1)
m = nn.Embedding(data.max() + 2, hidden_dims, padding_idx=-1)
nn.init.xavier_uniform_(m.weight)
module_dict[column] = m
if textset is not None:
for column, field in textset.items():
textlist, vocab, pad_var, batch_first = field
textlist, vocab, pad_var, batch_first = field
module_dict[column] = BagOfWords(vocab, hidden_dims)
return module_dict
class BagOfWords(nn.Module):
def __init__(self, vocab, hidden_dims):
super().__init__()
self.emb = nn.Embedding(
len(vocab.get_itos()), hidden_dims,
padding_idx=vocab.get_stoi()['<pad>'])
len(vocab.get_itos()),
hidden_dims,
padding_idx=vocab.get_stoi()["<pad>"],
)
nn.init.xavier_uniform_(self.emb.weight)
def forward(self, x, length):
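        # ``x`` holds padded token indices; the embeddings are summed over the
        # sequence dimension and divided by the true sequence lengths, giving a
        # length-normalized bag-of-words representation of each text field.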
return self.emb(x).sum(1) / length.unsqueeze(1).float()
class LinearProjector(nn.Module):
"""
Projects each input feature of the graph linearly and sums them up
"""
def __init__(self, full_graph, ntype, textset, hidden_dims):
super().__init__()
self.ntype = ntype
self.inputs = _init_input_modules(full_graph, ntype, textset, hidden_dims)
self.inputs = _init_input_modules(
full_graph, ntype, textset, hidden_dims
)
def forward(self, ndata):
projections = []
for feature, data in ndata.items():
if feature == dgl.NID or feature.endswith('__len'):
if feature == dgl.NID or feature.endswith("__len"):
# This is an additional feature indicating the length of the ``feature``
# column; we shouldn't process this.
continue
......@@ -75,7 +84,7 @@ class LinearProjector(nn.Module):
module = self.inputs[feature]
if isinstance(module, BagOfWords):
# Textual feature; find the length and pass it to the textual module.
length = ndata[feature + '__len']
length = ndata[feature + "__len"]
result = module(data, length)
else:
result = module(data)
......@@ -83,6 +92,7 @@ class LinearProjector(nn.Module):
return torch.stack(projections, 1).sum(1)
class WeightedSAGEConv(nn.Module):
def __init__(self, input_dims, hidden_dims, output_dims, act=F.relu):
super().__init__()
......@@ -94,7 +104,7 @@ class WeightedSAGEConv(nn.Module):
self.dropout = nn.Dropout(0.5)
def reset_parameters(self):
gain = nn.init.calculate_gain('relu')
gain = nn.init.calculate_gain("relu")
nn.init.xavier_uniform_(self.Q.weight, gain=gain)
nn.init.xavier_uniform_(self.W.weight, gain=gain)
nn.init.constant_(self.Q.bias, 0)
......@@ -108,18 +118,21 @@ class WeightedSAGEConv(nn.Module):
"""
h_src, h_dst = h
with g.local_scope():
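            # Weighted mean aggregation: each message act(Q(h_src)) is scaled by its
            # edge weight and summed into the destination node ("n"), the total edge
            # weight is accumulated separately ("ws", clamped to at least 1), and the
            # normalized sum is concatenated with h_dst before the final projection
            # and L2 normalization.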
g.srcdata['n'] = self.act(self.Q(self.dropout(h_src)))
g.edata['w'] = weights.float()
g.update_all(fn.u_mul_e('n', 'w', 'm'), fn.sum('m', 'n'))
g.update_all(fn.copy_e('w', 'm'), fn.sum('m', 'ws'))
n = g.dstdata['n']
ws = g.dstdata['ws'].unsqueeze(1).clamp(min=1)
g.srcdata["n"] = self.act(self.Q(self.dropout(h_src)))
g.edata["w"] = weights.float()
g.update_all(fn.u_mul_e("n", "w", "m"), fn.sum("m", "n"))
g.update_all(fn.copy_e("w", "m"), fn.sum("m", "ws"))
n = g.dstdata["n"]
ws = g.dstdata["ws"].unsqueeze(1).clamp(min=1)
z = self.act(self.W(self.dropout(torch.cat([n / ws, h_dst], 1))))
z_norm = z.norm(2, 1, keepdim=True)
z_norm = torch.where(z_norm == 0, torch.tensor(1.).to(z_norm), z_norm)
z_norm = torch.where(
z_norm == 0, torch.tensor(1.0).to(z_norm), z_norm
)
z = z / z_norm
return z
class SAGENet(nn.Module):
def __init__(self, hidden_dims, n_layers):
"""
......@@ -133,14 +146,17 @@ class SAGENet(nn.Module):
self.convs = nn.ModuleList()
for _ in range(n_layers):
self.convs.append(WeightedSAGEConv(hidden_dims, hidden_dims, hidden_dims))
self.convs.append(
WeightedSAGEConv(hidden_dims, hidden_dims, hidden_dims)
)
def forward(self, blocks, h):
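        # DGL blocks order the destination nodes first among the source nodes, so
        # the first ``num_nodes("DST/" + ntype)`` rows of ``h`` are the destination
        # nodes' own representations, which the conv concatenates with the
        # aggregated neighbor messages.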
for layer, block in zip(self.convs, blocks):
h_dst = h[:block.num_nodes('DST/' + block.ntypes[0])]
h = layer(block, (h, h_dst), block.edata['weights'])
h_dst = h[: block.num_nodes("DST/" + block.ntypes[0])]
h = layer(block, (h, h_dst), block.edata["weights"])
return h
class ItemToItemScorer(nn.Module):
def __init__(self, full_graph, ntype):
super().__init__()
......@@ -151,7 +167,7 @@ class ItemToItemScorer(nn.Module):
def _add_bias(self, edges):
bias_src = self.bias[edges.src[dgl.NID]]
bias_dst = self.bias[edges.dst[dgl.NID]]
return {'s': edges.data['s'] + bias_src + bias_dst}
return {"s": edges.data["s"] + bias_src + bias_dst}
def forward(self, item_item_graph, h):
"""
......@@ -159,8 +175,8 @@ class ItemToItemScorer(nn.Module):
h : hidden state of every node
"""
with item_item_graph.local_scope():
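            # The pair score is the dot product of the two items' hidden states plus
            # a learnable per-item bias for each endpoint (added in _add_bias).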
item_item_graph.ndata['h'] = h
item_item_graph.apply_edges(fn.u_dot_v('h', 'h', 's'))
item_item_graph.ndata["h"] = h
item_item_graph.apply_edges(fn.u_dot_v("h", "h", "s"))
item_item_graph.apply_edges(self._add_bias)
pair_score = item_item_graph.edata['s']
pair_score = item_item_graph.edata["s"]
return pair_score
import pickle
import argparse
import os
import pickle
import evaluation
import layers
import numpy as np
import sampler as sampler_module
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchtext
import dgl
import os
import tqdm
import layers
import sampler as sampler_module
import evaluation
from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import dgl
class PinSAGEModel(nn.Module):
def __init__(self, full_graph, ntype, textsets, hidden_dims, n_layers):
super().__init__()
self.proj = layers.LinearProjector(full_graph, ntype, textsets, hidden_dims)
self.proj = layers.LinearProjector(
full_graph, ntype, textsets, hidden_dims
)
self.sage = layers.SAGENet(hidden_dims, n_layers)
self.scorer = layers.ItemToItemScorer(full_graph, ntype)
......@@ -36,19 +40,22 @@ class PinSAGEModel(nn.Module):
# add to the item embedding itself
h_item = h_item + item_emb(blocks[0].srcdata[dgl.NID].cpu()).to(h_item)
h_item_dst = h_item_dst + item_emb(blocks[-1].dstdata[dgl.NID].cpu()).to(h_item_dst)
h_item_dst = h_item_dst + item_emb(
blocks[-1].dstdata[dgl.NID].cpu()
).to(h_item_dst)
return h_item_dst + self.sage(blocks, h_item)
def train(dataset, args):
g = dataset['train-graph']
val_matrix = dataset['val-matrix'].tocsr()
test_matrix = dataset['test-matrix'].tocsr()
item_texts = dataset['item-texts']
user_ntype = dataset['user-type']
item_ntype = dataset['item-type']
user_to_item_etype = dataset['user-to-item-type']
timestamp = dataset['timestamp-edge-column']
g = dataset["train-graph"]
val_matrix = dataset["val-matrix"].tocsr()
test_matrix = dataset["test-matrix"].tocsr()
item_texts = dataset["item-texts"]
user_ntype = dataset["user-type"]
item_ntype = dataset["item-type"]
user_to_item_etype = dataset["user-to-item-type"]
timestamp = dataset["timestamp-edge-column"]
device = torch.device(args.device)
......@@ -64,31 +71,53 @@ def train(dataset, args):
l = tokenizer(item_texts[key][i].lower())
textlist.append(l)
for key, field in item_texts.items():
vocab2 = build_vocab_from_iterator(textlist, specials=["<unk>","<pad>"])
textset[key] = (textlist, vocab2, vocab2.get_stoi()['<pad>'], batch_first)
vocab2 = build_vocab_from_iterator(
textlist, specials=["<unk>", "<pad>"]
)
textset[key] = (
textlist,
vocab2,
vocab2.get_stoi()["<pad>"],
batch_first,
)
# Sampler
batch_sampler = sampler_module.ItemToItemBatchSampler(
g, user_ntype, item_ntype, args.batch_size)
g, user_ntype, item_ntype, args.batch_size
)
neighbor_sampler = sampler_module.NeighborSampler(
g, user_ntype, item_ntype, args.random_walk_length,
args.random_walk_restart_prob, args.num_random_walks, args.num_neighbors,
args.num_layers)
collator = sampler_module.PinSAGECollator(neighbor_sampler, g, item_ntype, textset)
g,
user_ntype,
item_ntype,
args.random_walk_length,
args.random_walk_restart_prob,
args.num_random_walks,
args.num_neighbors,
args.num_layers,
)
collator = sampler_module.PinSAGECollator(
neighbor_sampler, g, item_ntype, textset
)
dataloader = DataLoader(
batch_sampler,
collate_fn=collator.collate_train,
num_workers=args.num_workers)
num_workers=args.num_workers,
)
dataloader_test = DataLoader(
torch.arange(g.num_nodes(item_ntype)),
batch_size=args.batch_size,
collate_fn=collator.collate_test,
num_workers=args.num_workers)
num_workers=args.num_workers,
)
dataloader_it = iter(dataloader)
# Model
model = PinSAGEModel(g, item_ntype, textset, args.hidden_dims, args.num_layers).to(device)
item_emb = nn.Embedding(g.num_nodes(item_ntype), args.hidden_dims, sparse=True)
model = PinSAGEModel(
g, item_ntype, textset, args.hidden_dims, args.num_layers
).to(device)
item_emb = nn.Embedding(
g.num_nodes(item_ntype), args.hidden_dims, sparse=True
)
# Optimizer
opt = torch.optim.Adam(model.parameters(), lr=args.lr)
opt_emb = torch.optim.SparseAdam(item_emb.parameters(), lr=args.lr)
......@@ -114,7 +143,9 @@ def train(dataset, args):
# Evaluate
model.eval()
with torch.no_grad():
item_batches = torch.arange(g.num_nodes(item_ntype)).split(args.batch_size)
item_batches = torch.arange(g.num_nodes(item_ntype)).split(
args.batch_size
)
h_item_batches = []
for blocks in tqdm.tqdm(dataloader_test):
for i in range(len(blocks)):
......@@ -123,32 +154,37 @@ def train(dataset, args):
h_item_batches.append(model.get_repr(blocks, item_emb))
h_item = torch.cat(h_item_batches, 0)
print(evaluation.evaluate_nn(dataset, h_item, args.k, args.batch_size))
print(
evaluation.evaluate_nn(dataset, h_item, args.k, args.batch_size)
)
if __name__ == '__main__':
if __name__ == "__main__":
# Arguments
parser = argparse.ArgumentParser()
parser.add_argument('dataset_path', type=str)
parser.add_argument('--random-walk-length', type=int, default=2)
parser.add_argument('--random-walk-restart-prob', type=float, default=0.5)
parser.add_argument('--num-random-walks', type=int, default=10)
parser.add_argument('--num-neighbors', type=int, default=3)
parser.add_argument('--num-layers', type=int, default=2)
parser.add_argument('--hidden-dims', type=int, default=16)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--device', type=str, default='cpu') # can also be "cuda:0"
parser.add_argument('--num-epochs', type=int, default=1)
parser.add_argument('--batches-per-epoch', type=int, default=20000)
parser.add_argument('--num-workers', type=int, default=0)
parser.add_argument('--lr', type=float, default=3e-5)
parser.add_argument('-k', type=int, default=10)
parser.add_argument("dataset_path", type=str)
parser.add_argument("--random-walk-length", type=int, default=2)
parser.add_argument("--random-walk-restart-prob", type=float, default=0.5)
parser.add_argument("--num-random-walks", type=int, default=10)
parser.add_argument("--num-neighbors", type=int, default=3)
parser.add_argument("--num-layers", type=int, default=2)
parser.add_argument("--hidden-dims", type=int, default=16)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument(
"--device", type=str, default="cpu"
) # can also be "cuda:0"
parser.add_argument("--num-epochs", type=int, default=1)
parser.add_argument("--batches-per-epoch", type=int, default=20000)
parser.add_argument("--num-workers", type=int, default=0)
parser.add_argument("--lr", type=float, default=3e-5)
parser.add_argument("-k", type=int, default=10)
args = parser.parse_args()
# Load dataset
data_info_path = os.path.join(args.dataset_path, 'data.pkl')
with open(data_info_path, 'rb') as f:
data_info_path = os.path.join(args.dataset_path, "data.pkl")
with open(data_info_path, "rb") as f:
dataset = pickle.load(f)
train_g_path = os.path.join(args.dataset_path, 'train_g.bin')
train_g_path = os.path.join(args.dataset_path, "train_g.bin")
g_list, _ = dgl.load_graphs(train_g_path)
dataset['train-graph'] = g_list[0]
dataset["train-graph"] = g_list[0]
train(dataset, args)