Unverified Commit be8763fa authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4679)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent eae6ce2a
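
The changes in this commit are mechanical formatting: Black normalizes string quotes to double quotes, adds PEP 8 spacing around slice bounds, and explodes calls and argument lists that overflow the line-length limit onto one argument per line. A minimal before/after illustration drawn from the first file below (the repository's exact Black configuration is assumed, not shown here):

    # before
    optimizer = th.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2norm)
    train_mask = g.nodes[category].data.pop('train_mask')

    # after running black
    optimizer = th.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.l2norm
    )
    train_mask = g.nodes[category].data.pop("train_mask")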
@@ -3,24 +3,26 @@ Paper: https://arxiv.org/abs/1703.06103
Reference Code: https://github.com/tkipf/relational-gcn
"""
import argparse
import time

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from model import EntityClassify_HeteroAPI

from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset


def main(args):
    # load graph data
    if args.dataset == "aifb":
        dataset = AIFBDataset()
    elif args.dataset == "mutag":
        dataset = MUTAGDataset()
    elif args.dataset == "bgs":
        dataset = BGSDataset()
    elif args.dataset == "am":
        dataset = AMDataset()
    else:
        raise ValueError()
@@ -28,11 +30,11 @@ def main(args):
    g = dataset[0]
    category = dataset.predict_category
    num_classes = dataset.num_classes
    train_mask = g.nodes[category].data.pop("train_mask")
    test_mask = g.nodes[category].data.pop("test_mask")
    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze()
    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze()
    labels = g.nodes[category].data.pop("labels")
    category_id = len(g.ntypes)
    for i, ntype in enumerate(g.ntypes):
        if ntype == category:
@@ -40,8 +42,8 @@ def main(args):
    # split dataset into train, validate, test
    if args.validation:
        val_idx = train_idx[: len(train_idx) // 5]
        train_idx = train_idx[len(train_idx) // 5 :]
    else:
        val_idx = train_idx
@@ -49,25 +51,29 @@ def main(args):
    use_cuda = args.gpu >= 0 and th.cuda.is_available()
    if use_cuda:
        th.cuda.set_device(args.gpu)
        g = g.to("cuda:%d" % args.gpu)
        labels = labels.cuda()
        train_idx = train_idx.cuda()
        test_idx = test_idx.cuda()

    # create model
    model = EntityClassify_HeteroAPI(
        g,
        args.n_hidden,
        num_classes,
        num_bases=args.n_bases,
        num_hidden_layers=args.n_layers - 2,
        dropout=args.dropout,
        use_self_loop=args.use_self_loop,
    )

    if use_cuda:
        model.cuda()

    # optimizer
    optimizer = th.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.l2norm
    )

    # training loop
    print("start training...")
@@ -83,11 +89,23 @@ def main(args):
        t1 = time.time()
        dur.append(t1 - t0)
        train_acc = th.sum(
            logits[train_idx].argmax(dim=1) == labels[train_idx]
        ).item() / len(train_idx)
        val_loss = F.cross_entropy(logits[val_idx], labels[val_idx])
        val_acc = th.sum(
            logits[val_idx].argmax(dim=1) == labels[val_idx]
        ).item() / len(val_idx)
        print(
            "Epoch {:05d} | Train Acc: {:.4f} | Train Loss: {:.4f} | Valid Acc: {:.4f} | Valid loss: {:.4f} | Time: {:.4f}".format(
                epoch,
                train_acc,
                loss.item(),
                val_acc,
                val_loss.item(),
                np.average(dur),
            )
        )
    print()
    if args.model_path is not None:
        th.save(model.state_dict(), args.model_path)
@@ -95,37 +113,59 @@ def main(args):
    model.eval()
    logits = model.forward()[category]
    test_loss = F.cross_entropy(logits[test_idx], labels[test_idx])
    test_acc = th.sum(
        logits[test_idx].argmax(dim=1) == labels[test_idx]
    ).item() / len(test_idx)
    print(
        "Test Acc: {:.4f} | Test loss: {:.4f}".format(
            test_acc, test_loss.item()
        )
    )
    print()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="RGCN")
    parser.add_argument(
        "--dropout", type=float, default=0, help="dropout probability"
    )
    parser.add_argument(
        "--n-hidden", type=int, default=16, help="number of hidden units"
    )
    parser.add_argument("--gpu", type=int, default=-1, help="gpu")
    parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
    parser.add_argument(
        "--n-bases",
        type=int,
        default=-1,
        help="number of filter weight matrices, default: -1 [use all]",
    )
    parser.add_argument(
        "--n-layers", type=int, default=2, help="number of propagation rounds"
    )
    parser.add_argument(
        "-e",
        "--n-epochs",
        type=int,
        default=50,
        help="number of training epochs",
    )
    parser.add_argument(
        "-d", "--dataset", type=str, required=True, help="dataset to use"
    )
    parser.add_argument(
        "--model_path", type=str, default=None, help="path for save the model"
    )
    parser.add_argument("--l2norm", type=float, default=0, help="l2 norm coef")
    parser.add_argument(
        "--use-self-loop",
        default=False,
        action="store_true",
        help="include self feature as a special relation",
    )
    fp = parser.add_mutually_exclusive_group(required=False)
    fp.add_argument("--validation", dest="validation", action="store_true")
    fp.add_argument("--testing", dest="validation", action="store_false")
    parser.set_defaults(validation=True)

    args = parser.parse_args()
...
@@ -4,14 +4,16 @@ Reference Code: https://github.com/tkipf/relational-gcn
"""
import argparse
import itertools
import time

import numpy as np
import torch as th
import torch.nn.functional as F
from model import EntityClassify, RelGraphEmbed

import dgl
from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset


def extract_embed(node_embed, input_nodes):
    emb = {}
@@ -20,6 +22,7 @@ def extract_embed(node_embed, input_nodes):
        emb[ntype] = node_embed[ntype][nid]
    return emb


def evaluate(model, loader, node_embed, labels, category, device):
    model.eval()
    total_loss = 0
@@ -39,22 +42,23 @@ def evaluate(model, loader, node_embed, labels, category, device):
        count += len(seeds)
    return total_loss / count, total_acc / count


def main(args):
    # check cuda
    device = "cpu"
    use_cuda = args.gpu >= 0 and th.cuda.is_available()
    if use_cuda:
        th.cuda.set_device(args.gpu)
        device = "cuda:%d" % args.gpu

    # load graph data
    if args.dataset == "aifb":
        dataset = AIFBDataset()
    elif args.dataset == "mutag":
        dataset = MUTAGDataset()
    elif args.dataset == "bgs":
        dataset = BGSDataset()
    elif args.dataset == "am":
        dataset = AMDataset()
    else:
        raise ValueError()
@@ -62,16 +66,16 @@ def main(args):
    g = dataset[0]
    category = dataset.predict_category
    num_classes = dataset.num_classes
    train_mask = g.nodes[category].data.pop("train_mask")
    test_mask = g.nodes[category].data.pop("test_mask")
    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze()
    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze()
    labels = g.nodes[category].data.pop("labels")

    # split dataset into train, validate, test
    if args.validation:
        val_idx = train_idx[: len(train_idx) // 5]
        train_idx = train_idx[len(train_idx) // 5 :]
    else:
        val_idx = train_idx
@@ -84,29 +88,45 @@ def main(args):
    node_embed = embed_layer()

    # create model
    model = EntityClassify(
        g,
        args.n_hidden,
        num_classes,
        num_bases=args.n_bases,
        num_hidden_layers=args.n_layers - 2,
        dropout=args.dropout,
        use_self_loop=args.use_self_loop,
    )

    if use_cuda:
        model.cuda()

    # train sampler
    sampler = dgl.dataloading.MultiLayerNeighborSampler(
        [args.fanout] * args.n_layers
    )
    loader = dgl.dataloading.DataLoader(
        g,
        {category: train_idx},
        sampler,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=0,
    )

    # validation sampler
    # we do not use full neighbor to save computation resources
    val_sampler = dgl.dataloading.MultiLayerNeighborSampler(
        [args.fanout] * args.n_layers
    )
    val_loader = dgl.dataloading.DataLoader(
        g,
        {category: val_idx},
        val_sampler,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=0,
    )

    # optimizer
    all_params = itertools.chain(model.parameters(), embed_layer.parameters())
@@ -123,12 +143,14 @@ def main(args):
        for i, (input_nodes, seeds, blocks) in enumerate(loader):
            blocks = [blk.to(device) for blk in blocks]
            seeds = seeds[
                category
            ]  # we only predict the nodes with type "category"
            batch_tic = time.time()
            emb = extract_embed(node_embed, input_nodes)
            lbl = labels[seeds]
            if use_cuda:
                emb = {k: e.cuda() for k, e in emb.items()}
                lbl = lbl.cuda()
            logits = model(emb, blocks)[category]
            loss = F.cross_entropy(logits, lbl)
@@ -136,63 +158,96 @@ def main(args):
            optimizer.step()

            train_acc = th.sum(logits.argmax(dim=1) == lbl).item() / len(seeds)
            print(
                "Epoch {:05d} | Batch {:03d} | Train Acc: {:.4f} | Train Loss: {:.4f} | Time: {:.4f}".format(
                    epoch, i, train_acc, loss.item(), time.time() - batch_tic
                )
            )

        if epoch > 3:
            dur.append(time.time() - t0)

        val_loss, val_acc = evaluate(
            model, val_loader, node_embed, labels, category, device
        )
        print(
            "Epoch {:05d} | Valid Acc: {:.4f} | Valid loss: {:.4f} | Time: {:.4f}".format(
                epoch, val_acc, val_loss, np.average(dur)
            )
        )
    print()
    if args.model_path is not None:
        th.save(model.state_dict(), args.model_path)

    output = model.inference(
        g, args.batch_size, "cuda" if use_cuda else "cpu", 0, node_embed
    )
    test_pred = output[category][test_idx]
    test_labels = labels[test_idx].to(test_pred.device)
    test_acc = (test_pred.argmax(1) == test_labels).float().mean()
    print("Test Acc: {:.4f}".format(test_acc))
    print()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="RGCN")
    parser.add_argument(
        "--dropout", type=float, default=0, help="dropout probability"
    )
    parser.add_argument(
        "--n-hidden", type=int, default=16, help="number of hidden units"
    )
    parser.add_argument("--gpu", type=int, default=-1, help="gpu")
    parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
    parser.add_argument(
        "--n-bases",
        type=int,
        default=-1,
        help="number of filter weight matrices, default: -1 [use all]",
    )
    parser.add_argument(
        "--n-layers", type=int, default=2, help="number of propagation rounds"
    )
    parser.add_argument(
        "-e",
        "--n-epochs",
        type=int,
        default=20,
        help="number of training epochs",
    )
    parser.add_argument(
        "-d", "--dataset", type=str, required=True, help="dataset to use"
    )
    parser.add_argument(
        "--model_path", type=str, default=None, help="path for save the model"
    )
    parser.add_argument("--l2norm", type=float, default=0, help="l2 norm coef")
    parser.add_argument(
        "--use-self-loop",
        default=False,
        action="store_true",
        help="include self feature as a special relation",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=100,
        help="Mini-batch size. If -1, use full graph training.",
    )
    parser.add_argument(
        "--fanout", type=int, default=4, help="Fan-out of neighbor sampling."
    )
    parser.add_argument(
        "--data-cpu",
        action="store_true",
        help="By default the script puts all node features and labels "
        "on GPU when using it to save time for data copy. This may "
        "be undesired if they cannot fit in GPU memory at once. "
        "This flag disables that.",
    )
    fp = parser.add_mutually_exclusive_group(required=False)
    fp.add_argument("--validation", dest="validation", action="store_true")
    fp.add_argument("--testing", dest="validation", action="store_false")
    parser.set_defaults(validation=True)

    args = parser.parse_args()
...
@@ -4,10 +4,12 @@ from collections import defaultdict
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import tqdm

import dgl
import dgl.function as fn
import dgl.nn as dglnn


class RelGraphConvLayer(nn.Module):
    r"""Relational graph convolution layer.
@@ -33,17 +35,20 @@ class RelGraphConvLayer(nn.Module):
    dropout : float, optional
        Dropout rate. Default: 0.0
    """

    def __init__(
        self,
        in_feat,
        out_feat,
        rel_names,
        num_bases,
        *,
        weight=True,
        bias=True,
        activation=None,
        self_loop=False,
        dropout=0.0
    ):
        super(RelGraphConvLayer, self).__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
@@ -53,19 +58,29 @@ class RelGraphConvLayer(nn.Module):
        self.activation = activation
        self.self_loop = self_loop

        self.conv = dglnn.HeteroGraphConv(
            {
                rel: dglnn.GraphConv(
                    in_feat, out_feat, norm="right", weight=False, bias=False
                )
                for rel in rel_names
            }
        )

        self.use_weight = weight
        self.use_basis = num_bases < len(self.rel_names) and weight
        if self.use_weight:
            if self.use_basis:
                self.basis = dglnn.WeightBasis(
                    (in_feat, out_feat), num_bases, len(self.rel_names)
                )
            else:
                self.weight = nn.Parameter(
                    th.Tensor(len(self.rel_names), in_feat, out_feat)
                )
                nn.init.xavier_uniform_(
                    self.weight, gain=nn.init.calculate_gain("relu")
                )

        # bias
        if bias:
@@ -75,8 +90,9 @@ class RelGraphConvLayer(nn.Module):
        # weight for self loop
        if self.self_loop:
            self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
            nn.init.xavier_uniform_(
                self.loop_weight, gain=nn.init.calculate_gain("relu")
            )

        self.dropout = nn.Dropout(dropout)
@@ -98,14 +114,18 @@ class RelGraphConvLayer(nn.Module):
        g = g.local_var()
        if self.use_weight:
            weight = self.basis() if self.use_basis else self.weight
            wdict = {
                self.rel_names[i]: {"weight": w.squeeze(0)}
                for i, w in enumerate(th.split(weight, 1, dim=0))
            }
        else:
            wdict = {}

        if g.is_block:
            inputs_src = inputs
            inputs_dst = {
                k: v[: g.number_of_dst_nodes(k)] for k, v in inputs.items()
            }
        else:
            inputs_src = inputs_dst = inputs
@@ -119,7 +139,9 @@ class RelGraphConvLayer(nn.Module):
            if self.activation:
                h = self.activation(h)
            return self.dropout(h)

        return {ntype: _apply(ntype, h) for ntype, h in hs.items()}
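
# Illustrative usage sketch (not part of this commit; `hetero_g` is a
# hypothetical DGLHeteroGraph). RelGraphConvLayer maps a dict of
# node-type -> feature tensor to another such dict:
#
#     layer = RelGraphConvLayer(
#         16, 16, hetero_g.etypes, num_bases=2, activation=F.relu
#     )
#     feats = {
#         ntype: th.randn(hetero_g.number_of_nodes(ntype), 16)
#         for ntype in hetero_g.ntypes
#     }
#     out = layer(hetero_g, feats)  # dict keyed by destination node type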


class RelGraphConvLayerHeteroAPI(nn.Module):
    r"""Relational graph convolution layer.
@@ -145,17 +167,20 @@ class RelGraphConvLayerHeteroAPI(nn.Module):
    dropout : float, optional
        Dropout rate. Default: 0.0
    """

    def __init__(
        self,
        in_feat,
        out_feat,
        rel_names,
        num_bases,
        *,
        weight=True,
        bias=True,
        activation=None,
        self_loop=False,
        dropout=0.0
    ):
        super(RelGraphConvLayerHeteroAPI, self).__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
@@ -169,10 +194,16 @@ class RelGraphConvLayerHeteroAPI(nn.Module):
        self.use_basis = num_bases < len(self.rel_names) and weight
        if self.use_weight:
            if self.use_basis:
                self.basis = dglnn.WeightBasis(
                    (in_feat, out_feat), num_bases, len(self.rel_names)
                )
            else:
                self.weight = nn.Parameter(
                    th.Tensor(len(self.rel_names), in_feat, out_feat)
                )
                nn.init.xavier_uniform_(
                    self.weight, gain=nn.init.calculate_gain("relu")
                )

        # bias
        if bias:
@@ -182,8 +213,9 @@ class RelGraphConvLayerHeteroAPI(nn.Module):
        # weight for self loop
        if self.self_loop:
            self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
            nn.init.xavier_uniform_(
                self.loop_weight, gain=nn.init.calculate_gain("relu")
            )

        self.dropout = nn.Dropout(dropout)
@@ -205,29 +237,33 @@ class RelGraphConvLayerHeteroAPI(nn.Module):
        g = g.local_var()
        if self.use_weight:
            weight = self.basis() if self.use_basis else self.weight
            wdict = {
                self.rel_names[i]: {"weight": w.squeeze(0)}
                for i, w in enumerate(th.split(weight, 1, dim=0))
            }
        else:
            wdict = {}

        inputs_src = inputs_dst = inputs

        for srctype, _, _ in g.canonical_etypes:
            g.nodes[srctype].data["h"] = inputs[srctype]

        if self.use_weight:
            g.apply_edges(fn.copy_u("h", "m"))
            m = g.edata["m"]
            for rel in g.canonical_etypes:
                _, etype, _ = rel
                g.edges[rel].data["h*w_r"] = th.matmul(
                    m[rel], wdict[etype]["weight"]
                )
        else:
            g.apply_edges(fn.copy_u("h", "h*w_r"))

        g.update_all(fn.copy_e("h*w_r", "m"), fn.sum("m", "h"))

        def _apply(ntype):
            h = g.nodes[ntype].data["h"]
            if self.self_loop:
                h = h + th.matmul(inputs_dst[ntype], self.loop_weight)
            if self.bias:
@@ -235,16 +271,16 @@ class RelGraphConvLayerHeteroAPI(nn.Module):
            if self.activation:
                h = self.activation(h)
            return self.dropout(h)

        return {ntype: _apply(ntype) for ntype in g.dsttypes}


class RelGraphEmbed(nn.Module):
    r"""Embedding layer for featureless heterograph."""

    def __init__(
        self, g, embed_size, embed_name="embed", activation=None, dropout=0.0
    ):
        super(RelGraphEmbed, self).__init__()
        self.g = g
        self.embed_size = embed_size
@@ -255,8 +291,10 @@ class RelGraphEmbed(nn.Module):
        # create weight embeddings for each node for each relation
        self.embeds = nn.ParameterDict()
        for ntype in g.ntypes:
            embed = nn.Parameter(
                th.Tensor(g.number_of_nodes(ntype), self.embed_size)
            )
            nn.init.xavier_uniform_(embed, gain=nn.init.calculate_gain("relu"))
            self.embeds[ntype] = embed

    def forward(self, block=None):
@@ -276,14 +314,18 @@ class RelGraphEmbed(nn.Module):
        """
        return self.embeds


class EntityClassify(nn.Module):
    def __init__(
        self,
        g,
        h_dim,
        out_dim,
        num_bases,
        num_hidden_layers=1,
        dropout=0,
        use_self_loop=False,
    ):
        super(EntityClassify, self).__init__()
        self.g = g
        self.h_dim = h_dim
@@ -301,21 +343,42 @@ class EntityClassify(nn.Module):
        self.embed_layer = RelGraphEmbed(g, self.h_dim)
        self.layers = nn.ModuleList()
        # i2h
        self.layers.append(
            RelGraphConvLayer(
                self.h_dim,
                self.h_dim,
                self.rel_names,
                self.num_bases,
                activation=F.relu,
                self_loop=self.use_self_loop,
                dropout=self.dropout,
                weight=False,
            )
        )
        # h2h
        for i in range(self.num_hidden_layers):
            self.layers.append(
                RelGraphConvLayer(
                    self.h_dim,
                    self.h_dim,
                    self.rel_names,
                    self.num_bases,
                    activation=F.relu,
                    self_loop=self.use_self_loop,
                    dropout=self.dropout,
                )
            )
        # h2o
        self.layers.append(
            RelGraphConvLayer(
                self.h_dim,
                self.out_dim,
                self.rel_names,
                self.num_bases,
                activation=None,
                self_loop=self.use_self_loop,
            )
        )

    def forward(self, h=None, blocks=None):
        if h is None:
@@ -346,8 +409,10 @@ class EntityClassify(nn.Module):
            y = {
                k: th.zeros(
                    g.number_of_nodes(k),
                    self.h_dim if l != len(self.layers) - 1 else self.out_dim,
                )
                for k in g.ntypes
            }

            sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
            dataloader = dgl.dataloading.DataLoader(
@@ -357,12 +422,16 @@ class EntityClassify(nn.Module):
                batch_size=batch_size,
                shuffle=True,
                drop_last=False,
                num_workers=num_workers,
            )

            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                block = blocks[0].to(device)

                h = {
                    k: x[k][input_nodes[k]].to(device)
                    for k in input_nodes.keys()
                }
                h = layer(block, h)

                for k in output_nodes.keys():
@@ -371,14 +440,18 @@ class EntityClassify(nn.Module):
            x = y
        return y


class EntityClassify_HeteroAPI(nn.Module):
    def __init__(
        self,
        g,
        h_dim,
        out_dim,
        num_bases,
        num_hidden_layers=1,
        dropout=0,
        use_self_loop=False,
    ):
        super(EntityClassify_HeteroAPI, self).__init__()
        self.g = g
        self.h_dim = h_dim
@@ -396,21 +469,42 @@ class EntityClassify_HeteroAPI(nn.Module):
        self.embed_layer = RelGraphEmbed(g, self.h_dim)
        self.layers = nn.ModuleList()
        # i2h
        self.layers.append(
            RelGraphConvLayerHeteroAPI(
                self.h_dim,
                self.h_dim,
                self.rel_names,
                self.num_bases,
                activation=F.relu,
                self_loop=self.use_self_loop,
                dropout=self.dropout,
                weight=False,
            )
        )
        # h2h
        for i in range(self.num_hidden_layers):
            self.layers.append(
                RelGraphConvLayerHeteroAPI(
                    self.h_dim,
                    self.h_dim,
                    self.rel_names,
                    self.num_bases,
                    activation=F.relu,
                    self_loop=self.use_self_loop,
                    dropout=self.dropout,
                )
            )
        # h2o
        self.layers.append(
            RelGraphConvLayerHeteroAPI(
                self.h_dim,
                self.out_dim,
                self.rel_names,
                self.num_bases,
                activation=None,
                self_loop=self.use_self_loop,
            )
        )

    def forward(self, h=None, blocks=None):
        if h is None:
@@ -441,8 +535,10 @@ class EntityClassify_HeteroAPI(nn.Module):
            y = {
                k: th.zeros(
                    g.number_of_nodes(k),
                    self.h_dim if l != len(self.layers) - 1 else self.out_dim,
                )
                for k in g.ntypes
            }

            sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
            dataloader = dgl.dataloading.DataLoader(
@@ -452,12 +548,16 @@ class EntityClassify_HeteroAPI(nn.Module):
                batch_size=batch_size,
                shuffle=True,
                drop_last=False,
                num_workers=num_workers,
            )

            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                block = blocks[0].to(device)

                h = {
                    k: x[k][input_nodes[k]].to(device)
                    for k in input_nodes.keys()
                }
                h = layer(block, h)

                for k in h.keys():
...
"""Infering Relational Data with Graph Convolutional Networks """Infering Relational Data with Graph Convolutional Networks
""" """
import argparse
from functools import partial

import torch as th
import torch.nn.functional as F
from entity_classify import EntityClassify

from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset
def main(args):
    # load graph data
    if args.dataset == "aifb":
        dataset = AIFBDataset()
    elif args.dataset == "mutag":
        dataset = MUTAGDataset()
    elif args.dataset == "bgs":
        dataset = BGSDataset()
    elif args.dataset == "am":
        dataset = AMDataset()
    else:
        raise ValueError()
@@ -24,9 +26,9 @@ def main(args):
    g = dataset[0]
    category = dataset.predict_category
    num_classes = dataset.num_classes
    test_mask = g.nodes[category].data.pop("test_mask")
    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze()
    labels = g.nodes[category].data.pop("labels")

    # check cuda
    use_cuda = args.gpu >= 0 and th.cuda.is_available()
@@ -34,15 +36,17 @@ def main(args):
        th.cuda.set_device(args.gpu)
        labels = labels.cuda()
        test_idx = test_idx.cuda()
        g = g.to("cuda:%d" % args.gpu)

    # create model
    model = EntityClassify(
        g,
        args.n_hidden,
        num_classes,
        num_bases=args.n_bases,
        num_hidden_layers=args.n_layers - 2,
        use_self_loop=args.use_self_loop,
    )
    model.load_state_dict(th.load(args.model_path))
    if use_cuda:
        model.cuda()
@@ -51,28 +55,45 @@ def main(args):
    model.eval()
    logits = model.forward()[category]
    test_loss = F.cross_entropy(logits[test_idx], labels[test_idx])
    test_acc = th.sum(
        logits[test_idx].argmax(dim=1) == labels[test_idx]
    ).item() / len(test_idx)
    print(
        "Test Acc: {:.4f} | Test loss: {:.4f}".format(
            test_acc, test_loss.item()
        )
    )
    print()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="RGCN")
    parser.add_argument(
        "--n-hidden", type=int, default=16, help="number of hidden units"
    )
    parser.add_argument("--gpu", type=int, default=-1, help="gpu")
    parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
    parser.add_argument(
        "--n-bases",
        type=int,
        default=-1,
        help="number of filter weight matrices, default: -1 [use all]",
    )
    parser.add_argument(
        "--n-layers", type=int, default=2, help="number of propagation rounds"
    )
    parser.add_argument(
        "-d", "--dataset", type=str, required=True, help="dataset to use"
    )
    parser.add_argument(
        "--model_path", type=str, help="path of the model to load from"
    )
    parser.add_argument(
        "--use-self-loop",
        default=False,
        action="store_true",
        help="include self feature as a special relation",
    )

    args = parser.parse_args()
    print(args)
...
@@ -9,6 +9,7 @@ References:
import torch
from torch import nn

import dgl.function as fn
@@ -21,26 +22,23 @@ class RRNLayer(nn.Module):
    def forward(self, g):
        g.apply_edges(self.get_msg)
        g.edata["e"] = self.edge_dropout(g.edata["e"])
        g.update_all(
            message_func=fn.copy_e("e", "msg"), reduce_func=fn.sum("msg", "m")
        )
        g.apply_nodes(self.node_update)

    def get_msg(self, edges):
        e = torch.cat([edges.src["h"], edges.dst["h"]], -1)
        e = self.msg_layer(e)
        return {"e": e}

    def node_update(self, nodes):
        return self.node_update_func(nodes)


class RRN(nn.Module):
    def __init__(self, msg_layer, node_update_func, num_steps, edge_drop):
        super(RRN, self).__init__()
        self.num_steps = num_steps
        self.rrn_layer = RRNLayer(msg_layer, node_update_func, edge_drop)
@@ -50,11 +48,9 @@ class RRN(nn.Module):
        for _ in range(self.num_steps):
            self.rrn_layer(g)
            if get_all_outputs:
                outputs.append(g.ndata["h"])

        if get_all_outputs:
            outputs = torch.stack(outputs, 0)  # num_steps x n_nodes x h_dim
        else:
            outputs = g.ndata["h"]  # n_nodes x h_dim
        return outputs
@@ -2,17 +2,13 @@
SudokuNN module based on RRN for solving sudoku puzzles
"""
import torch
from rrn import RRN
from torch import nn


class SudokuNN(nn.Module):
    def __init__(self, num_steps, embed_size=16, hidden_dim=96, edge_drop=0.1):
        super(SudokuNN, self).__init__()
        self.num_steps = num_steps
@@ -21,7 +17,7 @@ class SudokuNN(nn.Module):
        self.col_embed = nn.Embedding(9, embed_size)

        self.input_layer = nn.Sequential(
            nn.Linear(3 * embed_size, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
@@ -30,10 +26,10 @@ class SudokuNN(nn.Module):
            nn.Linear(hidden_dim, hidden_dim),
        )

        self.lstm = nn.LSTMCell(hidden_dim * 2, hidden_dim, bias=False)

        msg_layer = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
@@ -49,17 +45,17 @@ class SudokuNN(nn.Module):
        self.loss_func = nn.CrossEntropyLoss()

    def forward(self, g, is_training=True):
        labels = g.ndata.pop("a")

        input_digits = self.digit_embed(g.ndata.pop("q"))
        rows = self.row_embed(g.ndata.pop("row"))
        cols = self.col_embed(g.ndata.pop("col"))

        x = self.input_layer(torch.cat([input_digits, rows, cols], -1))
        g.ndata["x"] = x
        g.ndata["h"] = x
        g.ndata["rnn_h"] = torch.zeros_like(x, dtype=torch.float)
        g.ndata["rnn_c"] = torch.zeros_like(x, dtype=torch.float)

        outputs = self.rrn(g, is_training)
        logits = self.output_layer(outputs)
@@ -67,15 +63,18 @@ class SudokuNN(nn.Module):
        preds = torch.argmax(logits, -1)

        if is_training:
            labels = torch.stack([labels] * self.num_steps, 0)
            logits = logits.view([-1, 10])
            labels = labels.view([-1])
            loss = self.loss_func(logits, labels)
        return preds, loss

    def node_update_func(self, nodes):
        x, h, m, c = (
            nodes.data["x"],
            nodes.data["rnn_h"],
            nodes.data["m"],
            nodes.data["rnn_c"],
        )
        new_h, new_c = self.lstm(torch.cat([x, m], -1), (h, c))
        return {"h": new_h, "rnn_c": new_c, "rnn_h": new_h}
@@ -2,24 +2,28 @@ import csv
import os
import urllib.request
import zipfile
from copy import copy

import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.dataset import Dataset

import dgl


def _basic_sudoku_graph():
    grids = [
        [0, 1, 2, 9, 10, 11, 18, 19, 20],
        [3, 4, 5, 12, 13, 14, 21, 22, 23],
        [6, 7, 8, 15, 16, 17, 24, 25, 26],
        [27, 28, 29, 36, 37, 38, 45, 46, 47],
        [30, 31, 32, 39, 40, 41, 48, 49, 50],
        [33, 34, 35, 42, 43, 44, 51, 52, 53],
        [54, 55, 56, 63, 64, 65, 72, 73, 74],
        [57, 58, 59, 66, 67, 68, 75, 76, 77],
        [60, 61, 62, 69, 70, 71, 78, 79, 80],
    ]
    edges = set()
    for i in range(81):
        row, col = i // 9, i % 9
@@ -33,7 +37,7 @@ def _basic_sudoku_graph():
            col_src += 9
        # same grid
        grid_row, grid_col = row // 3, col // 3
        for n in grids[grid_row * 3 + grid_col]:
            if n != i:
                edges.add((n, i))
    edges = list(edges)
@@ -53,26 +57,26 @@ class ListDataset(Dataset):
        return len(self.lists_of_data[0])


def _get_sudoku_dataset(segment="train"):
    assert segment in ["train", "valid", "test"]

    url = "https://data.dgl.ai/dataset/sudoku-hard.zip"
    zip_fname = "/tmp/sudoku-hard.zip"
    dest_dir = "/tmp/sudoku-hard/"

    if not os.path.exists(dest_dir):
        print("Downloading data...")
        urllib.request.urlretrieve(url, zip_fname)
        with zipfile.ZipFile(zip_fname) as f:
            f.extractall("/tmp/")

    def read_csv(fname):
        print("Reading %s..." % fname)
        with open(dest_dir + fname) as f:
            reader = csv.reader(f, delimiter=",")
            return [(q, a) for q, a in reader]

    data = read_csv(segment + ".csv")

    def encode(samples):
        def parse(x):
@@ -82,12 +86,12 @@ def _get_sudoku_dataset(segment='train'):
        return encoded

    data = encode(data)
    print(f"Number of puzzles in {segment} set : {len(data)}")

    return data


def sudoku_dataloader(batch_size, segment="train"):
    """
    Get a DataLoader instance for dataset of sudoku. Every iteration of the dataloader returns
    a DGLGraph instance, the ndata of the graph contains:
@@ -104,7 +108,7 @@ def sudoku_dataloader(batch_size, segment='train'):
    q, a = zip(*data)
    dataset = ListDataset(q, a)

    if segment == "train":
        data_sampler = RandomSampler(dataset)
    else:
        data_sampler = SequentialSampler(dataset)
@@ -120,13 +124,15 @@ def sudoku_dataloader(batch_size, segment='train'):
            q = torch.tensor(q, dtype=torch.long)
            a = torch.tensor(a, dtype=torch.long)
            graph = copy(basic_graph)
            graph.ndata["q"] = q  # q means question
            graph.ndata["a"] = a  # a means answer
            graph.ndata["row"] = torch.tensor(rows, dtype=torch.long)
            graph.ndata["col"] = torch.tensor(cols, dtype=torch.long)
            graph_list.append(graph)
        batch_graph = dgl.batch(graph_list)
        return batch_graph

    dataloader = DataLoader(
        dataset, batch_size, sampler=data_sampler, collate_fn=collate_fn
    )

    return dataloader
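
A minimal usage sketch of the loader above (illustrative only, not part of this commit): each batch is one batched DGLGraph whose ndata carries the puzzle digits ("q"), the answers ("a"), and the row/column indices.

if __name__ == "__main__":
    # Hypothetical smoke test: fetch a single batch and inspect its node data.
    loader = sudoku_dataloader(batch_size=4, segment="valid")
    g = next(iter(loader))
    # 4 puzzles x 81 cells = 324 nodes in the batched graph
    print(g.ndata["q"].shape, g.ndata["a"].shape)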
import os
import urllib.request

import numpy as np
import torch
from sudoku import SudokuNN
from sudoku_data import _basic_sudoku_graph


def solve_sudoku(puzzle):
@@ -14,18 +14,18 @@ def solve_sudoku(puzzle):
    :return: a [9, 9] shaped numpy array
    """
    puzzle = np.array(puzzle, dtype=np.long).reshape([-1])

    model_path = "ckpt"
    if not os.path.exists(model_path):
        os.mkdir(model_path)

    model_filename = os.path.join(model_path, "rrn-sudoku.pkl")
    if not os.path.exists(model_filename):
        print("Downloading model...")
        url = "https://data.dgl.ai/models/rrn-sudoku.pkl"
        urllib.request.urlretrieve(url, model_filename)

    model = SudokuNN(num_steps=64, edge_drop=0.0)
    model.load_state_dict(torch.load(model_filename, map_location="cpu"))
    model.eval()

    g = _basic_sudoku_graph()
@@ -33,17 +33,17 @@ def solve_sudoku(puzzle):
    rows = sudoku_indices // 9
    cols = sudoku_indices % 9

    g.ndata["row"] = torch.tensor(rows, dtype=torch.long)
    g.ndata["col"] = torch.tensor(cols, dtype=torch.long)
    g.ndata["q"] = torch.tensor(puzzle, dtype=torch.long)
    g.ndata["a"] = torch.tensor(puzzle, dtype=torch.long)

    pred, _ = model(g, False)
    pred = pred.cpu().data.numpy().reshape([9, 9])

    return pred


if __name__ == "__main__":
    q = [
        [9, 7, 0, 4, 0, 2, 0, 5, 3],
        [0, 4, 6, 0, 9, 0, 0, 0, 0],
@@ -53,7 +53,7 @@ if __name__ == '__main__':
        [0, 0, 2, 8, 0, 0, 0, 0, 0],
        [6, 0, 5, 1, 0, 7, 2, 0, 0],
        [0, 0, 0, 0, 6, 0, 7, 4, 0],
        [4, 3, 0, 2, 0, 9, 0, 6, 1],
    ]

    answer = solve_sudoku(q)
...
from sudoku_data import sudoku_dataloader
import argparse import argparse
from sudoku import SudokuNN
import torch
from torch.optim import Adam
import os import os
import numpy as np import numpy as np
import torch
from sudoku import SudokuNN
from sudoku_data import sudoku_dataloader
from torch.optim import Adam
def main(args): def main(args):
if args.gpu < 0 or not torch.cuda.is_available(): if args.gpu < 0 or not torch.cuda.is_available():
device = torch.device('cpu') device = torch.device("cpu")
else: else:
device = torch.device('cuda', args.gpu) device = torch.device("cuda", args.gpu)
model = SudokuNN(num_steps=args.steps, edge_drop=args.edge_drop) model = SudokuNN(num_steps=args.steps, edge_drop=args.edge_drop)
...@@ -19,10 +20,12 @@ def main(args): ...@@ -19,10 +20,12 @@ def main(args):
if not os.path.exists(args.output_dir): if not os.path.exists(args.output_dir):
os.mkdir(args.output_dir) os.mkdir(args.output_dir)
model.to(device) model.to(device)
train_dataloader = sudoku_dataloader(args.batch_size, segment='train') train_dataloader = sudoku_dataloader(args.batch_size, segment="train")
dev_dataloader = sudoku_dataloader(args.batch_size, segment='valid') dev_dataloader = sudoku_dataloader(args.batch_size, segment="valid")
opt = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) opt = Adam(
model.parameters(), lr=args.lr, weight_decay=args.weight_decay
)
best_dev_acc = 0.0 best_dev_acc = 0.0
for epoch in range(args.epochs): for epoch in range(args.epochs):
...@@ -43,7 +46,7 @@ def main(args): ...@@ -43,7 +46,7 @@ def main(args):
dev_res = [] dev_res = []
for g in dev_dataloader: for g in dev_dataloader:
g = g.to(device) g = g.to(device)
target = g.ndata['a'] target = g.ndata["a"]
target = target.view([-1, 81]) target = target.view([-1, 81])
with torch.no_grad(): with torch.no_grad():
...@@ -51,28 +54,35 @@ def main(args): ...@@ -51,28 +54,35 @@ def main(args):
preds = preds.view([-1, 81]) preds = preds.view([-1, 81])
for i in range(preds.size(0)): for i in range(preds.size(0)):
dev_res.append(int(torch.equal(preds[i, :], target[i, :]))) dev_res.append(
int(torch.equal(preds[i, :], target[i, :]))
)
dev_loss.append(loss.cpu().detach().data) dev_loss.append(loss.cpu().detach().data)
dev_acc = sum(dev_res) / len(dev_res) dev_acc = sum(dev_res) / len(dev_res)
print(f"Dev loss {np.mean(dev_loss)}, accuracy {dev_acc}") print(f"Dev loss {np.mean(dev_loss)}, accuracy {dev_acc}")
if dev_acc >= best_dev_acc: if dev_acc >= best_dev_acc:
torch.save(model.state_dict(), os.path.join(args.output_dir, 'model_best.bin')) torch.save(
model.state_dict(),
os.path.join(args.output_dir, "model_best.bin"),
)
best_dev_acc = dev_acc best_dev_acc = dev_acc
print(f"Best dev accuracy {best_dev_acc}\n") print(f"Best dev accuracy {best_dev_acc}\n")
torch.save(model.state_dict(), os.path.join(args.output_dir, 'model_final.bin')) torch.save(
model.state_dict(), os.path.join(args.output_dir, "model_final.bin")
)
if args.do_eval: if args.do_eval:
model_path = os.path.join(args.output_dir, 'model_best.bin') model_path = os.path.join(args.output_dir, "model_best.bin")
if not os.path.exists(model_path): if not os.path.exists(model_path):
raise FileNotFoundError("Saved model not found!") raise FileNotFoundError("Saved model not found!")
model.load_state_dict(torch.load(model_path)) model.load_state_dict(torch.load(model_path))
model.to(device) model.to(device)
test_dataloader = sudoku_dataloader(args.batch_size, segment='test') test_dataloader = sudoku_dataloader(args.batch_size, segment="test")
print("\n=========Test step========") print("\n=========Test step========")
model.eval() model.eval()
...@@ -80,7 +90,7 @@ def main(args): ...@@ -80,7 +90,7 @@ def main(args):
test_res = [] test_res = []
for g in test_dataloader: for g in test_dataloader:
g = g.to(device) g = g.to(device)
target = g.ndata['a'] target = g.ndata["a"]
target = target.view([-1, 81]) target = target.view([-1, 81])
with torch.no_grad(): with torch.no_grad():
...@@ -98,30 +108,44 @@ def main(args): ...@@ -98,30 +108,44 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Recurrent Relational Network on sudoku task.') parser = argparse.ArgumentParser(
parser.add_argument("--output_dir", type=str, default=None, required=True, description="Recurrent Relational Network on sudoku task."
help="The directory to save model") )
parser.add_argument("--do_train", default=False, action="store_true", parser.add_argument(
help="Train the model") "--output_dir",
parser.add_argument("--do_eval", default=False, action="store_true", type=str,
help="Evaluate the model on test data") default=None,
parser.add_argument("--epochs", type=int, default=100, required=True,
help="Number of training epochs") help="The directory to save model",
parser.add_argument("--batch_size", type=int, default=64, )
help="Batch size") parser.add_argument(
parser.add_argument("--edge_drop", type=float, default=0.4, "--do_train", default=False, action="store_true", help="Train the model"
help="Dropout rate at edges.") )
parser.add_argument("--steps", type=int, default=32, parser.add_argument(
help="Number of message passing steps.") "--do_eval",
parser.add_argument("--gpu", type=int, default=-1, default=False,
help="gpu") action="store_true",
parser.add_argument("--lr", type=float, default=2e-4, help="Evaluate the model on test data",
help="Learning rate") )
parser.add_argument("--weight_decay", type=float, default=1e-4, parser.add_argument(
help="weight decay (L2 penalty)") "--epochs", type=int, default=100, help="Number of training epochs"
)
parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
parser.add_argument(
"--edge_drop", type=float, default=0.4, help="Dropout rate at edges."
)
parser.add_argument(
"--steps", type=int, default=32, help="Number of message passing steps."
)
parser.add_argument("--gpu", type=int, default=-1, help="gpu")
parser.add_argument("--lr", type=float, default=2e-4, help="Learning rate")
parser.add_argument(
"--weight_decay",
type=float,
default=1e-4,
help="weight decay (L2 penalty)",
)
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
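Given the flags registered above, a hypothetical training-plus-evaluation run of this script would look like the sketch below (the filename train.py is an assumption; only the flag names are taken from the parser above):

    # Hypothetical invocation of the sudoku training script:
    #   python train.py --output_dir ./ckpt --do_train --do_eval --epochs 100 --gpu 0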
import json import json
import os import os
from copy import deepcopy from copy import deepcopy
from main import main, parse_args from main import main, parse_args
from utils import get_stats from utils import get_stats
def load_config(path="./grid_search_config.json"): def load_config(path="./grid_search_config.json"):
with open(path, "r") as f: with open(path, "r") as f:
return json.load(f) return json.load(f)
def run_experiments(args): def run_experiments(args):
res = [] res = []
for i in range(args.num_trials): for i in range(args.num_trials):
print("Trial {}/{}".format(i + 1, args.num_trials)) print("Trial {}/{}".format(i + 1, args.num_trials))
acc, _ = main(args) acc, _ = main(args)
res.append(acc) res.append(acc)
mean, err_bd = get_stats(res, conf_interval=True) mean, err_bd = get_stats(res, conf_interval=True)
return mean, err_bd return mean, err_bd
def grid_search(config:dict):
def grid_search(config: dict):
args = parse_args() args = parse_args()
results = {} results = {}
for d in config["dataset"]: for d in config["dataset"]:
args.dataset = d args.dataset = d
best_acc, err_bd = 0., 0. best_acc, err_bd = 0.0, 0.0
best_args = vars(args) best_args = vars(args)
for arch in config["arch"]: for arch in config["arch"]:
args.architecture = arch args.architecture = arch
...@@ -47,9 +51,10 @@ def grid_search(config:dict): ...@@ -47,9 +51,10 @@ def grid_search(config:dict):
args.output_path = "./output/{}.log".format(d) args.output_path = "./output/{}.log".format(d)
result = { result = {
"params": best_args, "params": best_args,
"result": "{:.4f}({:.4f})".format(best_acc, err_bd) "result": "{:.4f}({:.4f})".format(best_acc, err_bd),
} }
with open(args.output_path, "w") as f: with open(args.output_path, "w") as f:
json.dump(result, f, sort_keys=True, indent=4) json.dump(result, f, sort_keys=True, indent=4)
grid_search(load_config()) grid_search(load_config())
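The search loop above reads only the "dataset" and "arch" lists from the JSON config, so a minimal grid_search_config.json could look like the sketch below (the concrete values are placeholders, not taken from the repository):

    import json

    # Hypothetical config; only the "dataset" and "arch" keys are read in the visible code.
    example_config = {
        "dataset": ["DD", "PROTEINS"],
        "arch": ["hierarchical", "global"],
    }
    with open("grid_search_config.json", "w") as f:
        json.dump(example_config, f, indent=4)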
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from utils import get_batch_id, topk
import dgl import dgl
from dgl.nn import GraphConv, AvgPooling, MaxPooling from dgl.nn import AvgPooling, GraphConv, MaxPooling
from utils import topk, get_batch_id
class SAGPool(torch.nn.Module): class SAGPool(torch.nn.Module):
"""The Self-Attention Pooling layer in paper """The Self-Attention Pooling layer in paper
`Self Attention Graph Pooling <https://arxiv.org/pdf/1904.08082.pdf>` `Self Attention Graph Pooling <https://arxiv.org/pdf/1904.08082.pdf>`
Args: Args:
...@@ -18,16 +19,28 @@ class SAGPool(torch.nn.Module): ...@@ -18,16 +19,28 @@ class SAGPool(torch.nn.Module):
non_linearity (Callable, optional): The non-linearity function, a pytorch function. non_linearity (Callable, optional): The non-linearity function, a pytorch function.
(default: :obj:`torch.tanh`) (default: :obj:`torch.tanh`)
""" """
def __init__(self, in_dim:int, ratio=0.5, conv_op=GraphConv, non_linearity=torch.tanh):
def __init__(
self,
in_dim: int,
ratio=0.5,
conv_op=GraphConv,
non_linearity=torch.tanh,
):
super(SAGPool, self).__init__() super(SAGPool, self).__init__()
self.in_dim = in_dim self.in_dim = in_dim
self.ratio = ratio self.ratio = ratio
self.score_layer = conv_op(in_dim, 1) self.score_layer = conv_op(in_dim, 1)
self.non_linearity = non_linearity self.non_linearity = non_linearity
def forward(self, graph:dgl.DGLGraph, feature:torch.Tensor): def forward(self, graph: dgl.DGLGraph, feature: torch.Tensor):
score = self.score_layer(graph, feature).squeeze() score = self.score_layer(graph, feature).squeeze()
perm, next_batch_num_nodes = topk(score, self.ratio, get_batch_id(graph.batch_num_nodes()), graph.batch_num_nodes()) perm, next_batch_num_nodes = topk(
score,
self.ratio,
get_batch_id(graph.batch_num_nodes()),
graph.batch_num_nodes(),
)
feature = feature[perm] * self.non_linearity(score[perm]).view(-1, 1) feature = feature[perm] * self.non_linearity(score[perm]).view(-1, 1)
graph = dgl.node_subgraph(graph, perm) graph = dgl.node_subgraph(graph, perm)
...@@ -37,7 +50,7 @@ class SAGPool(torch.nn.Module): ...@@ -37,7 +50,7 @@ class SAGPool(torch.nn.Module):
# Since global pooling has nothing to do with 'batch_num_edges', # Since global pooling has nothing to do with 'batch_num_edges',
# we can leave it as None or unchanged. # we can leave it as None or unchanged.
graph.set_batch_num_nodes(next_batch_num_nodes) graph.set_batch_num_nodes(next_batch_num_nodes)
return graph, feature, perm return graph, feature, perm
...@@ -45,15 +58,18 @@ class ConvPoolBlock(torch.nn.Module): ...@@ -45,15 +58,18 @@ class ConvPoolBlock(torch.nn.Module):
"""A combination of GCN layer and SAGPool layer, """A combination of GCN layer and SAGPool layer,
followed by a concatenated (mean||max) readout operation. followed by a concatenated (mean||max) readout operation.
""" """
def __init__(self, in_dim:int, out_dim:int, pool_ratio=0.8):
def __init__(self, in_dim: int, out_dim: int, pool_ratio=0.8):
super(ConvPoolBlock, self).__init__() super(ConvPoolBlock, self).__init__()
self.conv = GraphConv(in_dim, out_dim) self.conv = GraphConv(in_dim, out_dim)
self.pool = SAGPool(out_dim, ratio=pool_ratio) self.pool = SAGPool(out_dim, ratio=pool_ratio)
self.avgpool = AvgPooling() self.avgpool = AvgPooling()
self.maxpool = MaxPooling() self.maxpool = MaxPooling()
def forward(self, graph, feature): def forward(self, graph, feature):
out = F.relu(self.conv(graph, feature)) out = F.relu(self.conv(graph, feature))
graph, out, _ = self.pool(graph, out) graph, out, _ = self.pool(graph, out)
g_out = torch.cat([self.avgpool(graph, out), self.maxpool(graph, out)], dim=-1) g_out = torch.cat(
[self.avgpool(graph, out), self.maxpool(graph, out)], dim=-1
)
return graph, out, g_out return graph, out, g_out
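A rough sketch of driving the SAGPool layer above on a toy batched graph (graph and feature sizes are arbitrary assumptions; self-loops are added so that GraphConv does not reject zero-in-degree nodes):

    import dgl
    import torch
    from layer import SAGPool  # the layer defined above

    # Toy batch of two random graphs with 16-dimensional node features.
    g = dgl.batch([dgl.add_self_loop(dgl.rand_graph(6, 12)),
                   dgl.add_self_loop(dgl.rand_graph(4, 8))])
    feat = torch.randn(g.num_nodes(), 16)
    pool = SAGPool(in_dim=16, ratio=0.5)
    g_pooled, feat_pooled, perm = pool(g, feat)  # roughly half of each graph's nodes survive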
...@@ -3,54 +3,78 @@ import json ...@@ -3,54 +3,78 @@ import json
import logging import logging
import os import os
from time import time from time import time
import dgl
import torch import torch
import torch.nn import torch.nn
import torch.nn.functional as F import torch.nn.functional as F
from dgl.data import LegacyTUDataset
from dgl.dataloading import GraphDataLoader
from torch.utils.data import random_split
from network import get_sag_network from network import get_sag_network
from torch.utils.data import random_split
from utils import get_stats from utils import get_stats
import dgl
from dgl.data import LegacyTUDataset
from dgl.dataloading import GraphDataLoader
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description="Self-Attention Graph Pooling") parser = argparse.ArgumentParser(description="Self-Attention Graph Pooling")
parser.add_argument("--dataset", type=str, default="DD", parser.add_argument(
choices=["DD", "PROTEINS", "NCI1", "NCI109", "Mutagenicity"], "--dataset",
help="DD/PROTEINS/NCI1/NCI109/Mutagenicity") type=str,
parser.add_argument("--batch_size", type=int, default=128, default="DD",
help="batch size") choices=["DD", "PROTEINS", "NCI1", "NCI109", "Mutagenicity"],
parser.add_argument("--lr", type=float, default=5e-4, help="DD/PROTEINS/NCI1/NCI109/Mutagenicity",
help="learning rate") )
parser.add_argument("--weight_decay", type=float, default=1e-4, parser.add_argument(
help="weight decay") "--batch_size", type=int, default=128, help="batch size"
parser.add_argument("--pool_ratio", type=float, default=0.5, )
help="pooling ratio") parser.add_argument("--lr", type=float, default=5e-4, help="learning rate")
parser.add_argument("--hid_dim", type=int, default=128, parser.add_argument(
help="hidden size") "--weight_decay", type=float, default=1e-4, help="weight decay"
parser.add_argument("--dropout", type=float, default=0.5, )
help="dropout ratio") parser.add_argument(
parser.add_argument("--epochs", type=int, default=100000, "--pool_ratio", type=float, default=0.5, help="pooling ratio"
help="max number of training epochs") )
parser.add_argument("--patience", type=int, default=50, parser.add_argument("--hid_dim", type=int, default=128, help="hidden size")
help="patience for early stopping") parser.add_argument(
parser.add_argument("--device", type=int, default=-1, "--dropout", type=float, default=0.5, help="dropout ratio"
help="device id, -1 for cpu") )
parser.add_argument("--architecture", type=str, default="hierarchical", parser.add_argument(
choices=["hierarchical", "global"], "--epochs",
help="model architecture") type=int,
parser.add_argument("--dataset_path", type=str, default="./dataset", default=100000,
help="path to dataset") help="max number of training epochs",
parser.add_argument("--conv_layers", type=int, default=3, )
help="number of conv layers") parser.add_argument(
parser.add_argument("--print_every", type=int, default=10, "--patience", type=int, default=50, help="patience for early stopping"
help="print trainlog every k epochs, -1 for silent training") )
parser.add_argument("--num_trials", type=int, default=1, parser.add_argument(
help="number of trials") "--device", type=int, default=-1, help="device id, -1 for cpu"
)
parser.add_argument(
"--architecture",
type=str,
default="hierarchical",
choices=["hierarchical", "global"],
help="model architecture",
)
parser.add_argument(
"--dataset_path", type=str, default="./dataset", help="path to dataset"
)
parser.add_argument(
"--conv_layers", type=int, default=3, help="number of conv layers"
)
parser.add_argument(
"--print_every",
type=int,
default=10,
help="print trainlog every k epochs, -1 for silent training",
)
parser.add_argument(
"--num_trials", type=int, default=1, help="number of trials"
)
parser.add_argument("--output_path", type=str, default="./output") parser.add_argument("--output_path", type=str, default="./output")
args = parser.parse_args() args = parser.parse_args()
# device # device
...@@ -69,15 +93,21 @@ def parse_args(): ...@@ -69,15 +93,21 @@ def parse_args():
if not os.path.exists(args.output_path): if not os.path.exists(args.output_path):
os.makedirs(args.output_path) os.makedirs(args.output_path)
name = "Data={}_Hidden={}_Arch={}_Pool={}_WeightDecay={}_Lr={}.log".format( name = "Data={}_Hidden={}_Arch={}_Pool={}_WeightDecay={}_Lr={}.log".format(
args.dataset, args.hid_dim, args.architecture, args.pool_ratio, args.weight_decay, args.lr) args.dataset,
args.hid_dim,
args.architecture,
args.pool_ratio,
args.weight_decay,
args.lr,
)
args.output_path = os.path.join(args.output_path, name) args.output_path = os.path.join(args.output_path, name)
return args return args
def train(model:torch.nn.Module, optimizer, trainloader, device): def train(model: torch.nn.Module, optimizer, trainloader, device):
model.train() model.train()
total_loss = 0. total_loss = 0.0
num_batches = len(trainloader) num_batches = len(trainloader)
for batch in trainloader: for batch in trainloader:
optimizer.zero_grad() optimizer.zero_grad()
...@@ -90,15 +120,15 @@ def train(model:torch.nn.Module, optimizer, trainloader, device): ...@@ -90,15 +120,15 @@ def train(model:torch.nn.Module, optimizer, trainloader, device):
optimizer.step() optimizer.step()
total_loss += loss.item() total_loss += loss.item()
return total_loss / num_batches return total_loss / num_batches
@torch.no_grad() @torch.no_grad()
def test(model:torch.nn.Module, loader, device): def test(model: torch.nn.Module, loader, device):
model.eval() model.eval()
correct = 0. correct = 0.0
loss = 0. loss = 0.0
num_graphs = 0 num_graphs = 0
for batch in loader: for batch in loader:
batch_graphs, batch_labels = batch batch_graphs, batch_labels = batch
...@@ -124,29 +154,45 @@ def main(args): ...@@ -124,29 +154,45 @@ def main(args):
num_training = int(len(dataset) * 0.8) num_training = int(len(dataset) * 0.8)
num_val = int(len(dataset) * 0.1) num_val = int(len(dataset) * 0.1)
num_test = len(dataset) - num_val - num_training num_test = len(dataset) - num_val - num_training
train_set, val_set, test_set = random_split(dataset, [num_training, num_val, num_test]) train_set, val_set, test_set = random_split(
dataset, [num_training, num_val, num_test]
train_loader = GraphDataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=6) )
val_loader = GraphDataLoader(val_set, batch_size=args.batch_size, num_workers=2)
test_loader = GraphDataLoader(test_set, batch_size=args.batch_size, num_workers=2) train_loader = GraphDataLoader(
train_set, batch_size=args.batch_size, shuffle=True, num_workers=6
)
val_loader = GraphDataLoader(
val_set, batch_size=args.batch_size, num_workers=2
)
test_loader = GraphDataLoader(
test_set, batch_size=args.batch_size, num_workers=2
)
device = torch.device(args.device) device = torch.device(args.device)
# Step 2: Create model =================================================================== # # Step 2: Create model =================================================================== #
num_feature, num_classes, _ = dataset.statistics() num_feature, num_classes, _ = dataset.statistics()
model_op = get_sag_network(args.architecture) model_op = get_sag_network(args.architecture)
model = model_op(in_dim=num_feature, hid_dim=args.hid_dim, out_dim=num_classes, model = model_op(
num_convs=args.conv_layers, pool_ratio=args.pool_ratio, dropout=args.dropout).to(device) in_dim=num_feature,
hid_dim=args.hid_dim,
out_dim=num_classes,
num_convs=args.conv_layers,
pool_ratio=args.pool_ratio,
dropout=args.dropout,
).to(device)
args.num_feature = int(num_feature) args.num_feature = int(num_feature)
args.num_classes = int(num_classes) args.num_classes = int(num_classes)
# Step 3: Create training components ===================================================== # # Step 3: Create training components ===================================================== #
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) optimizer = torch.optim.Adam(
model.parameters(), lr=args.lr, weight_decay=args.weight_decay
)
# Step 4: training epochs =============================================================== # # Step 4: training epochs =============================================================== #
bad_cound = 0 bad_cound = 0
best_val_loss = float("inf") best_val_loss = float("inf")
final_test_acc = 0. final_test_acc = 0.0
best_epoch = 0 best_epoch = 0
train_times = [] train_times = []
for e in range(args.epochs): for e in range(args.epochs):
...@@ -164,11 +210,17 @@ def main(args): ...@@ -164,11 +210,17 @@ def main(args):
bad_cound += 1 bad_cound += 1
if bad_cound >= args.patience: if bad_cound >= args.patience:
break break
if (e + 1) % args.print_every == 0: if (e + 1) % args.print_every == 0:
log_format = "Epoch {}: loss={:.4f}, val_acc={:.4f}, final_test_acc={:.4f}" log_format = (
"Epoch {}: loss={:.4f}, val_acc={:.4f}, final_test_acc={:.4f}"
)
print(log_format.format(e + 1, train_loss, val_acc, final_test_acc)) print(log_format.format(e + 1, train_loss, val_acc, final_test_acc))
print("Best Epoch {}, final test acc {:.4f}".format(best_epoch, final_test_acc)) print(
"Best Epoch {}, final test acc {:.4f}".format(
best_epoch, final_test_acc
)
)
return final_test_acc, sum(train_times) / len(train_times) return final_test_acc, sum(train_times) / len(train_times)
...@@ -185,9 +237,11 @@ if __name__ == "__main__": ...@@ -185,9 +237,11 @@ if __name__ == "__main__":
mean, err_bd = get_stats(res) mean, err_bd = get_stats(res)
print("mean acc: {:.4f}, error bound: {:.4f}".format(mean, err_bd)) print("mean acc: {:.4f}, error bound: {:.4f}".format(mean, err_bd))
out_dict = {"hyper-parameters": vars(args), out_dict = {
"result": "{:.4f}(+-{:.4f})".format(mean, err_bd), "hyper-parameters": vars(args),
"train_time": "{:.4f}".format(sum(train_times) / len(train_times))} "result": "{:.4f}(+-{:.4f})".format(mean, err_bd),
"train_time": "{:.4f}".format(sum(train_times) / len(train_times)),
}
with open(args.output_path, "w") as f: with open(args.output_path, "w") as f:
json.dump(out_dict, f, sort_keys=True, indent=4) json.dump(out_dict, f, sort_keys=True, indent=4)
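With the parser defined above, a single run of this training script could be launched roughly as follows (a hypothetical invocation; the filename main.py matches the `from main import main, parse_args` used by the grid-search script earlier):

    # Hypothetical command line; flag names are taken from the parser above.
    #   python main.py --dataset DD --device 0 --architecture hierarchical --num_trials 5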
import torch import torch
import torch.nn import torch.nn
import torch.nn.functional as F import torch.nn.functional as F
import dgl
from dgl.nn import GraphConv, AvgPooling, MaxPooling
from layer import ConvPoolBlock, SAGPool from layer import ConvPoolBlock, SAGPool
import dgl
from dgl.nn import AvgPooling, GraphConv, MaxPooling
class SAGNetworkHierarchical(torch.nn.Module): class SAGNetworkHierarchical(torch.nn.Module):
"""The Self-Attention Graph Pooling Network with hierarchical readout in paper """The Self-Attention Graph Pooling Network with hierarchical readout in paper
...@@ -21,25 +21,35 @@ class SAGNetworkHierarchical(torch.nn.Module): ...@@ -21,25 +21,35 @@ class SAGNetworkHierarchical(torch.nn.Module):
remain after pooling. (default: :obj:`0.5`) remain after pooling. (default: :obj:`0.5`)
dropout (float, optional): The dropout ratio for each layer. (default: 0) dropout (float, optional): The dropout ratio for each layer. (default: 0)
""" """
def __init__(self, in_dim:int, hid_dim:int, out_dim:int, num_convs=3,
pool_ratio:float=0.5, dropout:float=0.0): def __init__(
self,
in_dim: int,
hid_dim: int,
out_dim: int,
num_convs=3,
pool_ratio: float = 0.5,
dropout: float = 0.0,
):
super(SAGNetworkHierarchical, self).__init__() super(SAGNetworkHierarchical, self).__init__()
self.dropout = dropout self.dropout = dropout
self.num_convpools = num_convs self.num_convpools = num_convs
convpools = [] convpools = []
for i in range(num_convs): for i in range(num_convs):
_i_dim = in_dim if i == 0 else hid_dim _i_dim = in_dim if i == 0 else hid_dim
_o_dim = hid_dim _o_dim = hid_dim
convpools.append(ConvPoolBlock(_i_dim, _o_dim, pool_ratio=pool_ratio)) convpools.append(
ConvPoolBlock(_i_dim, _o_dim, pool_ratio=pool_ratio)
)
self.convpools = torch.nn.ModuleList(convpools) self.convpools = torch.nn.ModuleList(convpools)
self.lin1 = torch.nn.Linear(hid_dim * 2, hid_dim) self.lin1 = torch.nn.Linear(hid_dim * 2, hid_dim)
self.lin2 = torch.nn.Linear(hid_dim, hid_dim // 2) self.lin2 = torch.nn.Linear(hid_dim, hid_dim // 2)
self.lin3 = torch.nn.Linear(hid_dim // 2, out_dim) self.lin3 = torch.nn.Linear(hid_dim // 2, out_dim)
def forward(self, graph:dgl.DGLGraph): def forward(self, graph: dgl.DGLGraph):
feat = graph.ndata["feat"] feat = graph.ndata["feat"]
final_readout = None final_readout = None
...@@ -49,12 +59,12 @@ class SAGNetworkHierarchical(torch.nn.Module): ...@@ -49,12 +59,12 @@ class SAGNetworkHierarchical(torch.nn.Module):
final_readout = readout final_readout = readout
else: else:
final_readout = final_readout + readout final_readout = final_readout + readout
feat = F.relu(self.lin1(final_readout)) feat = F.relu(self.lin1(final_readout))
feat = F.dropout(feat, p=self.dropout, training=self.training) feat = F.dropout(feat, p=self.dropout, training=self.training)
feat = F.relu(self.lin2(feat)) feat = F.relu(self.lin2(feat))
feat = F.log_softmax(self.lin3(feat), dim=-1) feat = F.log_softmax(self.lin3(feat), dim=-1)
return feat return feat
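A minimal smoke-test sketch of the hierarchical network above (toy graphs and feature sizes are assumptions; in the real pipeline they come from the dataset):

    import dgl
    import torch
    from network import SAGNetworkHierarchical  # the class defined above

    g = dgl.batch([dgl.add_self_loop(dgl.rand_graph(8, 20)) for _ in range(2)])
    g.ndata["feat"] = torch.randn(g.num_nodes(), 16)
    net = SAGNetworkHierarchical(in_dim=16, hid_dim=32, out_dim=2, num_convs=3)
    log_probs = net(g)  # shape [2, 2]: per-graph class log-probabilities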
...@@ -72,8 +82,16 @@ class SAGNetworkGlobal(torch.nn.Module): ...@@ -72,8 +82,16 @@ class SAGNetworkGlobal(torch.nn.Module):
remain after pooling. (default: :obj:`0.5`) remain after pooling. (default: :obj:`0.5`)
dropout (float, optional): The dropout ratio for each layer. (default: 0) dropout (float, optional): The dropout ratio for each layer. (default: 0)
""" """
def __init__(self, in_dim:int, hid_dim:int, out_dim:int, num_convs=3,
pool_ratio:float=0.5, dropout:float=0.0): def __init__(
self,
in_dim: int,
hid_dim: int,
out_dim: int,
num_convs=3,
pool_ratio: float = 0.5,
dropout: float = 0.0,
):
super(SAGNetworkGlobal, self).__init__() super(SAGNetworkGlobal, self).__init__()
self.dropout = dropout self.dropout = dropout
self.num_convs = num_convs self.num_convs = num_convs
...@@ -93,18 +111,21 @@ class SAGNetworkGlobal(torch.nn.Module): ...@@ -93,18 +111,21 @@ class SAGNetworkGlobal(torch.nn.Module):
self.lin1 = torch.nn.Linear(concat_dim * 2, hid_dim) self.lin1 = torch.nn.Linear(concat_dim * 2, hid_dim)
self.lin2 = torch.nn.Linear(hid_dim, hid_dim // 2) self.lin2 = torch.nn.Linear(hid_dim, hid_dim // 2)
self.lin3 = torch.nn.Linear(hid_dim // 2, out_dim) self.lin3 = torch.nn.Linear(hid_dim // 2, out_dim)
def forward(self, graph:dgl.DGLGraph): def forward(self, graph: dgl.DGLGraph):
feat = graph.ndata["feat"] feat = graph.ndata["feat"]
conv_res = [] conv_res = []
for i in range(self.num_convs): for i in range(self.num_convs):
feat = self.convs[i](graph, feat) feat = self.convs[i](graph, feat)
conv_res.append(feat) conv_res.append(feat)
conv_res = torch.cat(conv_res, dim=-1) conv_res = torch.cat(conv_res, dim=-1)
graph, feat, _ = self.pool(graph, conv_res) graph, feat, _ = self.pool(graph, conv_res)
feat = torch.cat([self.avg_readout(graph, feat), self.max_readout(graph, feat)], dim=-1) feat = torch.cat(
[self.avg_readout(graph, feat), self.max_readout(graph, feat)],
dim=-1,
)
feat = F.relu(self.lin1(feat)) feat = F.relu(self.lin1(feat))
feat = F.dropout(feat, p=self.dropout, training=self.training) feat = F.dropout(feat, p=self.dropout, training=self.training)
...@@ -114,10 +135,12 @@ class SAGNetworkGlobal(torch.nn.Module): ...@@ -114,10 +135,12 @@ class SAGNetworkGlobal(torch.nn.Module):
return feat return feat
def get_sag_network(net_type:str="hierarchical"): def get_sag_network(net_type: str = "hierarchical"):
if net_type == "hierarchical": if net_type == "hierarchical":
return SAGNetworkHierarchical return SAGNetworkHierarchical
elif net_type == "global": elif net_type == "global":
return SAGNetworkGlobal return SAGNetworkGlobal
else: else:
raise ValueError("SAGNetwork type {} is not supported.".format(net_type)) raise ValueError(
"SAGNetwork type {} is not supported.".format(net_type)
)
import torch
import logging import logging
from scipy.stats import t
import math import math
import torch
from scipy.stats import t
def get_stats(array, conf_interval=False, name=None, stdout=False, logout=False): def get_stats(
array, conf_interval=False, name=None, stdout=False, logout=False
):
"""Compute mean and standard deviation from an numerical array """Compute mean and standard deviation from an numerical array
Args: Args:
array (array like obj): The numerical array; it can be array (array like obj): The numerical array; it can be
converted to :obj:`torch.Tensor`. converted to :obj:`torch.Tensor`.
conf_interval (bool, optional): If True, compute the confidence interval bound (95%) conf_interval (bool, optional): If True, compute the confidence interval bound (95%)
instead of the std value. (default: :obj:`False`) instead of the std value. (default: :obj:`False`)
name (str, optional): The name of this numerical array, for log usage. name (str, optional): The name of this numerical array, for log usage.
(default: :obj:`None`) (default: :obj:`None`)
stdout (bool, optional): Whether to output result to the terminal. stdout (bool, optional): Whether to output result to the terminal.
(default: :obj:`False`) (default: :obj:`False`)
logout (bool, optional): Whether to output result via logging module. logout (bool, optional): Whether to output result via logging module.
(default: :obj:`False`) (default: :obj:`False`)
...@@ -29,7 +32,7 @@ def get_stats(array, conf_interval=False, name=None, stdout=False, logout=False) ...@@ -29,7 +32,7 @@ def get_stats(array, conf_interval=False, name=None, stdout=False, logout=False)
if conf_interval: if conf_interval:
n = array.size(0) n = array.size(0)
se = std / (math.sqrt(n) + eps) se = std / (math.sqrt(n) + eps)
t_value = t.ppf(0.975, df=n-1) t_value = t.ppf(0.975, df=n - 1)
err_bound = t_value * se err_bound = t_value * se
else: else:
err_bound = std err_bound = std
...@@ -46,7 +49,7 @@ def get_stats(array, conf_interval=False, name=None, stdout=False, logout=False) ...@@ -46,7 +49,7 @@ def get_stats(array, conf_interval=False, name=None, stdout=False, logout=False)
return center, err_bound return center, err_bound
def get_batch_id(num_nodes:torch.Tensor): def get_batch_id(num_nodes: torch.Tensor):
"""Convert the num_nodes array obtained from batch graph to batch_id array """Convert the num_nodes array obtained from batch graph to batch_id array
for each node. for each node.
...@@ -57,12 +60,19 @@ def get_batch_id(num_nodes:torch.Tensor): ...@@ -57,12 +60,19 @@ def get_batch_id(num_nodes:torch.Tensor):
batch_size = num_nodes.size(0) batch_size = num_nodes.size(0)
batch_ids = [] batch_ids = []
for i in range(batch_size): for i in range(batch_size):
item = torch.full((num_nodes[i],), i, dtype=torch.long, device=num_nodes.device) item = torch.full(
(num_nodes[i],), i, dtype=torch.long, device=num_nodes.device
)
batch_ids.append(item) batch_ids.append(item)
return torch.cat(batch_ids) return torch.cat(batch_ids)
def topk(x:torch.Tensor, ratio:float, batch_id:torch.Tensor, num_nodes:torch.Tensor): def topk(
x: torch.Tensor,
ratio: float,
batch_id: torch.Tensor,
num_nodes: torch.Tensor,
):
"""The top-k pooling method. Given a graph batch, this method will pool out some """The top-k pooling method. Given a graph batch, this method will pool out some
nodes from the input node feature tensor for each graph according to the given ratio. nodes from the input node feature tensor for each graph according to the given ratio.
...@@ -72,21 +82,23 @@ def topk(x:torch.Tensor, ratio:float, batch_id:torch.Tensor, num_nodes:torch.Ten ...@@ -72,21 +82,23 @@ def topk(x:torch.Tensor, ratio:float, batch_id:torch.Tensor, num_nodes:torch.Ten
tensor will be pooled out. tensor will be pooled out.
batch_id (torch.Tensor): The batch_id of each element in the input tensor. batch_id (torch.Tensor): The batch_id of each element in the input tensor.
num_nodes (torch.Tensor): The number of nodes of each graph in batch. num_nodes (torch.Tensor): The number of nodes of each graph in batch.
Returns: Returns:
perm (torch.Tensor): The index in batch to be kept. perm (torch.Tensor): The index in batch to be kept.
k (torch.Tensor): The remaining number of nodes for each graph. k (torch.Tensor): The remaining number of nodes for each graph.
""" """
batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item() batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item()
cum_num_nodes = torch.cat( cum_num_nodes = torch.cat(
[num_nodes.new_zeros(1), [num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]], dim=0
num_nodes.cumsum(dim=0)[:-1]], dim=0) )
index = torch.arange(batch_id.size(0), dtype=torch.long, device=x.device) index = torch.arange(batch_id.size(0), dtype=torch.long, device=x.device)
index = (index - cum_num_nodes[batch_id]) + (batch_id * max_num_nodes) index = (index - cum_num_nodes[batch_id]) + (batch_id * max_num_nodes)
dense_x = x.new_full((batch_size * max_num_nodes, ), torch.finfo(x.dtype).min) dense_x = x.new_full(
(batch_size * max_num_nodes,), torch.finfo(x.dtype).min
)
dense_x[index] = x dense_x[index] = x
dense_x = dense_x.view(batch_size, max_num_nodes) dense_x = dense_x.view(batch_size, max_num_nodes)
...@@ -96,8 +108,10 @@ def topk(x:torch.Tensor, ratio:float, batch_id:torch.Tensor, num_nodes:torch.Ten ...@@ -96,8 +108,10 @@ def topk(x:torch.Tensor, ratio:float, batch_id:torch.Tensor, num_nodes:torch.Ten
k = (ratio * num_nodes.to(torch.float)).ceil().to(torch.long) k = (ratio * num_nodes.to(torch.float)).ceil().to(torch.long)
mask = [ mask = [
torch.arange(k[i], dtype=torch.long, device=x.device) + torch.arange(k[i], dtype=torch.long, device=x.device)
i * max_num_nodes for i in range(batch_size)] + i * max_num_nodes
for i in range(batch_size)
]
mask = torch.cat(mask, dim=0) mask = torch.cat(mask, dim=0)
perm = perm[mask] perm = perm[mask]
......
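To make the batching arithmetic in get_batch_id and topk concrete, a hand-worked toy case (the numbers are illustrative; the final sort-and-mask step of topk sits in the elided part of the function):

    import torch
    from utils import get_batch_id, topk  # the helpers defined above

    num_nodes = torch.tensor([3, 2])    # a batch of two graphs with 3 and 2 nodes
    batch_id = get_batch_id(num_nodes)  # tensor([0, 0, 0, 1, 1])

    scores = torch.tensor([0.9, 0.1, 0.5, 0.7, 0.3])
    perm, k = topk(scores, 0.5, batch_id, num_nodes)
    # k = ceil(0.5 * [3, 2]) = [2, 1]; perm indexes the two best-scoring nodes of the
    # first graph and the single best-scoring node of the second, in batched node order.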
import logging import logging
import time
import os import os
import time
def _transform_log_level(str_level): def _transform_log_level(str_level):
if str_level == 'info': if str_level == "info":
return logging.INFO return logging.INFO
elif str_level == 'warning': elif str_level == "warning":
return logging.WARNING return logging.WARNING
elif str_level == 'critical': elif str_level == "critical":
return logging.CRITICAL return logging.CRITICAL
elif str_level == 'debug': elif str_level == "debug":
return logging.DEBUG return logging.DEBUG
elif str_level == 'error': elif str_level == "error":
return logging.ERROR return logging.ERROR
else: else:
raise KeyError('Log level error') raise KeyError("Log level error")
class LightLogging(object): class LightLogging(object):
def __init__(self, log_path=None, log_name='lightlog', log_level='debug'): def __init__(self, log_path=None, log_name="lightlog", log_level="debug"):
log_level = _transform_log_level(log_level) log_level = _transform_log_level(log_level)
if log_path: if log_path:
if not log_path.endswith('/'): if not log_path.endswith("/"):
log_path += '/' log_path += "/"
if not os.path.exists(log_path): if not os.path.exists(log_path):
os.mkdir(log_path) os.mkdir(log_path)
if log_name.endswith('-') or log_name.endswith('_'): if log_name.endswith("-") or log_name.endswith("_"):
log_name = log_path+log_name + time.strftime('%Y-%m-%d-%H:%M', time.localtime(time.time())) + '.log' log_name = (
log_path
+ log_name
+ time.strftime(
"%Y-%m-%d-%H:%M", time.localtime(time.time())
)
+ ".log"
)
else: else:
log_name = log_path+log_name + '_' + time.strftime('%Y-%m-%d-%H-%M', time.localtime(time.time())) + '.log' log_name = (
log_path
+ log_name
+ "_"
+ time.strftime(
"%Y-%m-%d-%H-%M", time.localtime(time.time())
)
+ ".log"
)
logging.basicConfig(level=log_level, logging.basicConfig(
format="%(asctime)s %(levelname)s: %(message)s", level=log_level,
datefmt='%Y-%m-%d-%H:%M', format="%(asctime)s %(levelname)s: %(message)s",
handlers=[ datefmt="%Y-%m-%d-%H:%M",
logging.FileHandler(log_name, mode='w'), handlers=[
logging.StreamHandler() logging.FileHandler(log_name, mode="w"),
]) logging.StreamHandler(),
logging.info('Start Logging') ],
logging.info('Log file path: {}'.format(log_name)) )
logging.info("Start Logging")
logging.info("Log file path: {}".format(log_name))
else: else:
logging.basicConfig(level=log_level, logging.basicConfig(
format="%(asctime)s %(levelname)s: %(message)s", level=log_level,
datefmt='%Y-%m-%d-%H:%M', format="%(asctime)s %(levelname)s: %(message)s",
handlers=[ datefmt="%Y-%m-%d-%H:%M",
logging.StreamHandler() handlers=[logging.StreamHandler()],
]) )
logging.info('Start Logging') logging.info("Start Logging")
def debug(self, msg): def debug(self, msg):
logging.debug(msg) logging.debug(msg)
...@@ -66,4 +83,4 @@ class LightLogging(object): ...@@ -66,4 +83,4 @@ class LightLogging(object):
logging.warning(msg) logging.warning(msg)
def error(self, msg): def error(self, msg):
logging.error(msg) logging.error(msg)
\ No newline at end of file
import time import time
from tqdm import tqdm
import numpy as np import numpy as np
import torch import torch
import torch.multiprocessing
from logger import LightLogging
from model import DGCNN, GCN
from sampler import SEALData
from torch.nn import BCEWithLogitsLoss from torch.nn import BCEWithLogitsLoss
from dgl import NID, EID from tqdm import tqdm
from utils import evaluate_hits, load_ogb_dataset, parse_arguments
from dgl import EID, NID
from dgl.dataloading import GraphDataLoader from dgl.dataloading import GraphDataLoader
from utils import parse_arguments
from utils import load_ogb_dataset, evaluate_hits
from sampler import SEALData
from model import GCN, DGCNN
from logger import LightLogging
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
''' torch.multiprocessing.set_sharing_strategy("file_system")
"""
Part of the code is adapted from Part of the code is adapted from
https://github.com/facebookresearch/SEAL_OGB https://github.com/facebookresearch/SEAL_OGB
''' """
def train(model, dataloader, loss_fn, optimizer, device, num_graphs=32, total_graphs=None): def train(
model,
dataloader,
loss_fn,
optimizer,
device,
num_graphs=32,
total_graphs=None,
):
model.train() model.train()
total_loss = 0 total_loss = 0
...@@ -27,7 +37,7 @@ def train(model, dataloader, loss_fn, optimizer, device, num_graphs=32, total_gr ...@@ -27,7 +37,7 @@ def train(model, dataloader, loss_fn, optimizer, device, num_graphs=32, total_gr
g = g.to(device) g = g.to(device)
labels = labels.to(device) labels = labels.to(device)
optimizer.zero_grad() optimizer.zero_grad()
logits = model(g, g.ndata['z'], g.ndata[NID], g.edata[EID]) logits = model(g, g.ndata["z"], g.ndata[NID], g.edata[EID])
loss = loss_fn(logits, labels) loss = loss_fn(logits, labels)
loss.backward() loss.backward()
optimizer.step() optimizer.step()
...@@ -43,7 +53,7 @@ def evaluate(model, dataloader, device): ...@@ -43,7 +53,7 @@ def evaluate(model, dataloader, device):
y_pred, y_true = [], [] y_pred, y_true = [], []
for g, labels in tqdm(dataloader, ncols=100): for g, labels in tqdm(dataloader, ncols=100):
g = g.to(device) g = g.to(device)
logits = model(g, g.ndata['z'], g.ndata[NID], g.edata[EID]) logits = model(g, g.ndata["z"], g.ndata[NID], g.edata[EID])
y_pred.append(logits.view(-1).cpu()) y_pred.append(logits.view(-1).cpu())
y_true.append(labels.view(-1).cpu().to(torch.float)) y_true.append(labels.view(-1).cpu().to(torch.float))
...@@ -62,7 +72,7 @@ def main(args, print_fn=print): ...@@ -62,7 +72,7 @@ def main(args, print_fn=print):
else: else:
torch.manual_seed(123) torch.manual_seed(123)
# Load dataset # Load dataset
if args.dataset.startswith('ogbl'): if args.dataset.startswith("ogbl"):
graph, split_edge = load_ogb_dataset(args.dataset) graph, split_edge = load_ogb_dataset(args.dataset)
else: else:
raise NotImplementedError raise NotImplementedError
...@@ -71,11 +81,11 @@ def main(args, print_fn=print): ...@@ -71,11 +81,11 @@ def main(args, print_fn=print):
# set gpu # set gpu
if args.gpu_id >= 0 and torch.cuda.is_available(): if args.gpu_id >= 0 and torch.cuda.is_available():
device = 'cuda:{}'.format(args.gpu_id) device = "cuda:{}".format(args.gpu_id)
else: else:
device = 'cpu' device = "cpu"
if args.dataset == 'ogbl-collab': if args.dataset == "ogbl-collab":
# The ogbl-collab dataset is a multi-edge graph # The ogbl-collab dataset is a multi-edge graph
use_coalesce = True use_coalesce = True
else: else:
...@@ -83,94 +93,135 @@ def main(args, print_fn=print): ...@@ -83,94 +93,135 @@ def main(args, print_fn=print):
# Generate positive and negative edges and corresponding labels # Generate positive and negative edges and corresponding labels
# Sampling subgraphs and generate node labeling features # Sampling subgraphs and generate node labeling features
seal_data = SEALData(g=graph, split_edge=split_edge, hop=args.hop, neg_samples=args.neg_samples, seal_data = SEALData(
subsample_ratio=args.subsample_ratio, use_coalesce=use_coalesce, prefix=args.dataset, g=graph,
save_dir=args.save_dir, num_workers=args.num_workers, print_fn=print_fn) split_edge=split_edge,
node_attribute = seal_data.ndata['feat'] hop=args.hop,
edge_weight = seal_data.edata['weight'].float() neg_samples=args.neg_samples,
subsample_ratio=args.subsample_ratio,
train_data = seal_data('train') use_coalesce=use_coalesce,
val_data = seal_data('valid') prefix=args.dataset,
test_data = seal_data('test') save_dir=args.save_dir,
num_workers=args.num_workers,
print_fn=print_fn,
)
node_attribute = seal_data.ndata["feat"]
edge_weight = seal_data.edata["weight"].float()
train_data = seal_data("train")
val_data = seal_data("valid")
test_data = seal_data("test")
train_graphs = len(train_data.graph_list) train_graphs = len(train_data.graph_list)
# Set data loader # Set data loader
train_loader = GraphDataLoader(train_data, batch_size=args.batch_size, num_workers=args.num_workers) train_loader = GraphDataLoader(
val_loader = GraphDataLoader(val_data, batch_size=args.batch_size, num_workers=args.num_workers) train_data, batch_size=args.batch_size, num_workers=args.num_workers
test_loader = GraphDataLoader(test_data, batch_size=args.batch_size, num_workers=args.num_workers) )
val_loader = GraphDataLoader(
val_data, batch_size=args.batch_size, num_workers=args.num_workers
)
test_loader = GraphDataLoader(
test_data, batch_size=args.batch_size, num_workers=args.num_workers
)
# set model # set model
if args.model == 'gcn': if args.model == "gcn":
model = GCN(num_layers=args.num_layers, model = GCN(
hidden_units=args.hidden_units, num_layers=args.num_layers,
gcn_type=args.gcn_type, hidden_units=args.hidden_units,
pooling_type=args.pooling, gcn_type=args.gcn_type,
node_attributes=node_attribute, pooling_type=args.pooling,
edge_weights=edge_weight, node_attributes=node_attribute,
node_embedding=None, edge_weights=edge_weight,
use_embedding=True, node_embedding=None,
num_nodes=num_nodes, use_embedding=True,
dropout=args.dropout) num_nodes=num_nodes,
elif args.model == 'dgcnn': dropout=args.dropout,
model = DGCNN(num_layers=args.num_layers, )
hidden_units=args.hidden_units, elif args.model == "dgcnn":
k=args.sort_k, model = DGCNN(
gcn_type=args.gcn_type, num_layers=args.num_layers,
node_attributes=node_attribute, hidden_units=args.hidden_units,
edge_weights=edge_weight, k=args.sort_k,
node_embedding=None, gcn_type=args.gcn_type,
use_embedding=True, node_attributes=node_attribute,
num_nodes=num_nodes, edge_weights=edge_weight,
dropout=args.dropout) node_embedding=None,
use_embedding=True,
num_nodes=num_nodes,
dropout=args.dropout,
)
else: else:
raise ValueError('Model error') raise ValueError("Model error")
model = model.to(device) model = model.to(device)
parameters = model.parameters() parameters = model.parameters()
optimizer = torch.optim.Adam(parameters, lr=args.lr) optimizer = torch.optim.Adam(parameters, lr=args.lr)
loss_fn = BCEWithLogitsLoss() loss_fn = BCEWithLogitsLoss()
print_fn("Total parameters: {}".format(sum([p.numel() for p in model.parameters()]))) print_fn(
"Total parameters: {}".format(
sum([p.numel() for p in model.parameters()])
)
)
# train and evaluate loop # train and evaluate loop
summary_val = [] summary_val = []
summary_test = [] summary_test = []
for epoch in range(args.epochs): for epoch in range(args.epochs):
start_time = time.time() start_time = time.time()
loss = train(model=model, loss = train(
dataloader=train_loader, model=model,
loss_fn=loss_fn, dataloader=train_loader,
optimizer=optimizer, loss_fn=loss_fn,
device=device, optimizer=optimizer,
num_graphs=args.batch_size, device=device,
total_graphs=train_graphs) num_graphs=args.batch_size,
total_graphs=train_graphs,
)
train_time = time.time() train_time = time.time()
if epoch % args.eval_steps == 0: if epoch % args.eval_steps == 0:
val_pos_pred, val_neg_pred = evaluate(model=model, val_pos_pred, val_neg_pred = evaluate(
dataloader=val_loader, model=model, dataloader=val_loader, device=device
device=device) )
test_pos_pred, test_neg_pred = evaluate(model=model, test_pos_pred, test_neg_pred = evaluate(
dataloader=test_loader, model=model, dataloader=test_loader, device=device
device=device) )
val_metric = evaluate_hits(args.dataset, val_pos_pred, val_neg_pred, args.hits_k) val_metric = evaluate_hits(
test_metric = evaluate_hits(args.dataset, test_pos_pred, test_neg_pred, args.hits_k) args.dataset, val_pos_pred, val_neg_pred, args.hits_k
)
test_metric = evaluate_hits(
args.dataset, test_pos_pred, test_neg_pred, args.hits_k
)
evaluate_time = time.time() evaluate_time = time.time()
print_fn("Epoch-{}, train loss: {:.4f}, hits@{}: val-{:.4f}, test-{:.4f}, " print_fn(
"cost time: train-{:.1f}s, total-{:.1f}s".format(epoch, loss, args.hits_k, val_metric, test_metric, "Epoch-{}, train loss: {:.4f}, hits@{}: val-{:.4f}, test-{:.4f}, "
train_time - start_time, "cost time: train-{:.1f}s, total-{:.1f}s".format(
evaluate_time - start_time)) epoch,
loss,
args.hits_k,
val_metric,
test_metric,
train_time - start_time,
evaluate_time - start_time,
)
)
summary_val.append(val_metric) summary_val.append(val_metric)
summary_test.append(test_metric) summary_test.append(test_metric)
summary_test = np.array(summary_test) summary_test = np.array(summary_test)
print_fn("Experiment Results:") print_fn("Experiment Results:")
print_fn("Best hits@{}: {:.4f}, epoch: {}".format(args.hits_k, np.max(summary_test), np.argmax(summary_test))) print_fn(
"Best hits@{}: {:.4f}, epoch: {}".format(
args.hits_k, np.max(summary_test), np.argmax(summary_test)
)
)
if __name__ == '__main__': if __name__ == "__main__":
args = parse_arguments() args = parse_arguments()
logger = LightLogging(log_name='SEAL', log_path='./logs') logger = LightLogging(log_name="SEAL", log_path="./logs")
main(args, logger.info) main(args, logger.info)
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from dgl.nn.pytorch import SortPooling, SumPooling
from dgl.nn.pytorch import GraphConv, SAGEConv from dgl.nn.pytorch import GraphConv, SAGEConv, SortPooling, SumPooling
class GCN(nn.Module): class GCN(nn.Module):
...@@ -26,9 +26,20 @@ class GCN(nn.Module): ...@@ -26,9 +26,20 @@ class GCN(nn.Module):
""" """
def __init__(self, num_layers, hidden_units, gcn_type='gcn', pooling_type='sum', node_attributes=None, def __init__(
edge_weights=None, node_embedding=None, use_embedding=False, self,
num_nodes=None, dropout=0.5, max_z=1000): num_layers,
hidden_units,
gcn_type="gcn",
pooling_type="sum",
node_attributes=None,
edge_weights=None,
node_embedding=None,
use_embedding=False,
num_nodes=None,
dropout=0.5,
max_z=1000,
):
super(GCN, self).__init__() super(GCN, self).__init__()
self.num_layers = num_layers self.num_layers = num_layers
self.dropout = dropout self.dropout = dropout
...@@ -39,10 +50,14 @@ class GCN(nn.Module): ...@@ -39,10 +50,14 @@ class GCN(nn.Module):
self.z_embedding = nn.Embedding(max_z, hidden_units) self.z_embedding = nn.Embedding(max_z, hidden_units)
if node_attributes is not None: if node_attributes is not None:
self.node_attributes_lookup = nn.Embedding.from_pretrained(node_attributes) self.node_attributes_lookup = nn.Embedding.from_pretrained(
node_attributes
)
self.node_attributes_lookup.weight.requires_grad = False self.node_attributes_lookup.weight.requires_grad = False
if edge_weights is not None: if edge_weights is not None:
self.edge_weights_lookup = nn.Embedding.from_pretrained(edge_weights) self.edge_weights_lookup = nn.Embedding.from_pretrained(
edge_weights
)
self.edge_weights_lookup.weight.requires_grad = False self.edge_weights_lookup.weight.requires_grad = False
if node_embedding is not None: if node_embedding is not None:
self.node_embedding = nn.Embedding.from_pretrained(node_embedding) self.node_embedding = nn.Embedding.from_pretrained(node_embedding)
...@@ -57,21 +72,31 @@ class GCN(nn.Module): ...@@ -57,21 +72,31 @@ class GCN(nn.Module):
initial_dim += self.node_embedding.embedding_dim initial_dim += self.node_embedding.embedding_dim
self.layers = nn.ModuleList() self.layers = nn.ModuleList()
if gcn_type == 'gcn': if gcn_type == "gcn":
self.layers.append(GraphConv(initial_dim, hidden_units, allow_zero_in_degree=True)) self.layers.append(
GraphConv(initial_dim, hidden_units, allow_zero_in_degree=True)
)
for _ in range(num_layers - 1): for _ in range(num_layers - 1):
self.layers.append(GraphConv(hidden_units, hidden_units, allow_zero_in_degree=True)) self.layers.append(
elif gcn_type == 'sage': GraphConv(
self.layers.append(SAGEConv(initial_dim, hidden_units, aggregator_type='gcn')) hidden_units, hidden_units, allow_zero_in_degree=True
)
)
elif gcn_type == "sage":
self.layers.append(
SAGEConv(initial_dim, hidden_units, aggregator_type="gcn")
)
for _ in range(num_layers - 1): for _ in range(num_layers - 1):
self.layers.append(SAGEConv(hidden_units, hidden_units, aggregator_type='gcn')) self.layers.append(
SAGEConv(hidden_units, hidden_units, aggregator_type="gcn")
)
else: else:
raise ValueError('GCN type error.') raise ValueError("GCN type error.")
self.linear_1 = nn.Linear(hidden_units, hidden_units) self.linear_1 = nn.Linear(hidden_units, hidden_units)
self.linear_2 = nn.Linear(hidden_units, 1) self.linear_2 = nn.Linear(hidden_units, 1)
if pooling_type != 'sum': if pooling_type != "sum":
raise ValueError('Pooling type error.') raise ValueError("Pooling type error.")
self.pooling = SumPooling() self.pooling = SumPooling()
def reset_parameters(self): def reset_parameters(self):
...@@ -141,8 +166,20 @@ class DGCNN(nn.Module): ...@@ -141,8 +166,20 @@ class DGCNN(nn.Module):
max_z(int, optional): max vocab size for node labeling (default: 1000). max_z(int, optional): max vocab size for node labeling (default: 1000).
""" """
def __init__(self, num_layers, hidden_units, k=10, gcn_type='gcn', node_attributes=None, def __init__(
edge_weights=None, node_embedding=None, use_embedding=False, num_nodes=None, dropout=0.5, max_z=1000): self,
num_layers,
hidden_units,
k=10,
gcn_type="gcn",
node_attributes=None,
edge_weights=None,
node_embedding=None,
use_embedding=False,
num_nodes=None,
dropout=0.5,
max_z=1000,
):
super(DGCNN, self).__init__() super(DGCNN, self).__init__()
self.num_layers = num_layers self.num_layers = num_layers
self.dropout = dropout self.dropout = dropout
...@@ -153,10 +190,14 @@ class DGCNN(nn.Module): ...@@ -153,10 +190,14 @@ class DGCNN(nn.Module):
self.z_embedding = nn.Embedding(max_z, hidden_units) self.z_embedding = nn.Embedding(max_z, hidden_units)
if node_attributes is not None: if node_attributes is not None:
self.node_attributes_lookup = nn.Embedding.from_pretrained(node_attributes) self.node_attributes_lookup = nn.Embedding.from_pretrained(
node_attributes
)
self.node_attributes_lookup.weight.requires_grad = False self.node_attributes_lookup.weight.requires_grad = False
if edge_weights is not None: if edge_weights is not None:
self.edge_weights_lookup = nn.Embedding.from_pretrained(edge_weights) self.edge_weights_lookup = nn.Embedding.from_pretrained(
edge_weights
)
self.edge_weights_lookup.weight.requires_grad = False self.edge_weights_lookup.weight.requires_grad = False
if node_embedding is not None: if node_embedding is not None:
self.node_embedding = nn.Embedding.from_pretrained(node_embedding) self.node_embedding = nn.Embedding.from_pretrained(node_embedding)
...@@ -171,28 +212,42 @@ class DGCNN(nn.Module): ...@@ -171,28 +212,42 @@ class DGCNN(nn.Module):
initial_dim += self.node_embedding.embedding_dim initial_dim += self.node_embedding.embedding_dim
self.layers = nn.ModuleList() self.layers = nn.ModuleList()
if gcn_type == 'gcn': if gcn_type == "gcn":
self.layers.append(GraphConv(initial_dim, hidden_units, allow_zero_in_degree=True)) self.layers.append(
GraphConv(initial_dim, hidden_units, allow_zero_in_degree=True)
)
for _ in range(num_layers - 1): for _ in range(num_layers - 1):
self.layers.append(GraphConv(hidden_units, hidden_units, allow_zero_in_degree=True)) self.layers.append(
self.layers.append(GraphConv(hidden_units, 1, allow_zero_in_degree=True)) GraphConv(
elif gcn_type == 'sage': hidden_units, hidden_units, allow_zero_in_degree=True
self.layers.append(SAGEConv(initial_dim, hidden_units, aggregator_type='gcn')) )
)
self.layers.append(
GraphConv(hidden_units, 1, allow_zero_in_degree=True)
)
elif gcn_type == "sage":
self.layers.append(
SAGEConv(initial_dim, hidden_units, aggregator_type="gcn")
)
for _ in range(num_layers - 1): for _ in range(num_layers - 1):
self.layers.append(SAGEConv(hidden_units, hidden_units, aggregator_type='gcn')) self.layers.append(
self.layers.append(SAGEConv(hidden_units, 1, aggregator_type='gcn')) SAGEConv(hidden_units, hidden_units, aggregator_type="gcn")
)
self.layers.append(SAGEConv(hidden_units, 1, aggregator_type="gcn"))
else: else:
raise ValueError('GCN type error.') raise ValueError("GCN type error.")
self.pooling = SortPooling(k=k) self.pooling = SortPooling(k=k)
conv1d_channels = [16, 32] conv1d_channels = [16, 32]
total_latent_dim = hidden_units * num_layers + 1 total_latent_dim = hidden_units * num_layers + 1
conv1d_kws = [total_latent_dim, 5] conv1d_kws = [total_latent_dim, 5]
self.conv_1 = nn.Conv1d(1, conv1d_channels[0], conv1d_kws[0], self.conv_1 = nn.Conv1d(
conv1d_kws[0]) 1, conv1d_channels[0], conv1d_kws[0], conv1d_kws[0]
)
self.maxpool1d = nn.MaxPool1d(2, 2) self.maxpool1d = nn.MaxPool1d(2, 2)
self.conv_2 = nn.Conv1d(conv1d_channels[0], conv1d_channels[1], self.conv_2 = nn.Conv1d(
conv1d_kws[1], 1) conv1d_channels[0], conv1d_channels[1], conv1d_kws[1], 1
)
dense_dim = int((k - 2) / 2 + 1) dense_dim = int((k - 2) / 2 + 1)
dense_dim = (dense_dim - conv1d_kws[1] + 1) * conv1d_channels[1] dense_dim = (dense_dim - conv1d_kws[1] + 1) * conv1d_channels[1]
self.linear_1 = nn.Linear(dense_dim, 128) self.linear_1 = nn.Linear(dense_dim, 128)
......
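The 1-D convolution sizes in the DGCNN head above can be checked by hand. With the default k=10 and, purely for illustration, hidden_units=32 and num_layers=3 (both assumptions):

    hidden_units, num_layers, k = 32, 3, 10           # assumed values for the worked example
    total_latent_dim = hidden_units * num_layers + 1  # 97 features per node after SortPooling
    after_conv1 = k                                   # conv_1 (kernel = stride = 97) leaves one position per node
    after_pool = (after_conv1 - 2) // 2 + 1           # 5, from MaxPool1d(2, 2)
    dense_dim = (after_pool - 5 + 1) * 32             # 32, the input width of linear_1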
import os.path as osp import os.path as osp
from tqdm import tqdm
from copy import deepcopy from copy import deepcopy
import torch import torch
import dgl
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
from dgl import DGLGraph, NID from tqdm import tqdm
from dgl.dataloading.negative_sampler import Uniform
from dgl import add_self_loop
from utils import drnl_node_labeling from utils import drnl_node_labeling
import dgl
from dgl import NID, DGLGraph, add_self_loop
from dgl.dataloading.negative_sampler import Uniform
class GraphDataSet(Dataset): class GraphDataSet(Dataset):
""" """
...@@ -37,7 +38,9 @@ class PosNegEdgesGenerator(object): ...@@ -37,7 +38,9 @@ class PosNegEdgesGenerator(object):
shuffle(bool): whether to shuffle the generated graph list shuffle(bool): whether to shuffle the generated graph list
""" """
def __init__(self, g, split_edge, neg_samples=1, subsample_ratio=0.1, shuffle=True): def __init__(
self, g, split_edge, neg_samples=1, subsample_ratio=0.1, shuffle=True
):
self.neg_sampler = Uniform(neg_samples) self.neg_sampler = Uniform(neg_samples)
self.subsample_ratio = subsample_ratio self.subsample_ratio = subsample_ratio
self.split_edge = split_edge self.split_edge = split_edge
...@@ -46,24 +49,29 @@ class PosNegEdgesGenerator(object): ...@@ -46,24 +49,29 @@ class PosNegEdgesGenerator(object):
def __call__(self, split_type): def __call__(self, split_type):
if split_type == 'train': if split_type == "train":
subsample_ratio = self.subsample_ratio subsample_ratio = self.subsample_ratio
else: else:
subsample_ratio = 1 subsample_ratio = 1
pos_edges = self.split_edge[split_type]['edge'] pos_edges = self.split_edge[split_type]["edge"]
if split_type == 'train': if split_type == "train":
# Adding self loop in train avoids sampling the source node itself. # Adding self loop in train avoids sampling the source node itself.
g = add_self_loop(self.g) g = add_self_loop(self.g)
eids = g.edge_ids(pos_edges[:, 0], pos_edges[:, 1]) eids = g.edge_ids(pos_edges[:, 0], pos_edges[:, 1])
neg_edges = torch.stack(self.neg_sampler(g, eids), dim=1) neg_edges = torch.stack(self.neg_sampler(g, eids), dim=1)
else: else:
neg_edges = self.split_edge[split_type]['edge_neg'] neg_edges = self.split_edge[split_type]["edge_neg"]
pos_edges = self.subsample(pos_edges, subsample_ratio).long() pos_edges = self.subsample(pos_edges, subsample_ratio).long()
neg_edges = self.subsample(neg_edges, subsample_ratio).long() neg_edges = self.subsample(neg_edges, subsample_ratio).long()
edges = torch.cat([pos_edges, neg_edges]) edges = torch.cat([pos_edges, neg_edges])
labels = torch.cat([torch.ones(pos_edges.size(0), 1), torch.zeros(neg_edges.size(0), 1)]) labels = torch.cat(
[
torch.ones(pos_edges.size(0), 1),
torch.zeros(neg_edges.size(0), 1),
]
)
if self.shuffle: if self.shuffle:
perm = torch.randperm(edges.size(0)) perm = torch.randperm(edges.size(0))
edges = edges[perm] edges = edges[perm]
...@@ -84,7 +92,7 @@ class PosNegEdgesGenerator(object): ...@@ -84,7 +92,7 @@ class PosNegEdgesGenerator(object):
num_edges = edges.size(0) num_edges = edges.size(0)
perm = torch.randperm(num_edges) perm = torch.randperm(num_edges)
perm = perm[:int(subsample_ratio * num_edges)] perm = perm[: int(subsample_ratio * num_edges)]
edges = edges[perm] edges = edges[perm]
return edges return edges
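A minimal sketch, assuming a toy graph, of how the Uniform negative sampler used in PosNegEdgesGenerator.__call__ above is driven by edge IDs; every name and value below is illustrative and not part of this commit.

import dgl
import torch
from dgl.dataloading.negative_sampler import Uniform

# Toy graph with self-loops added, mirroring the train branch above.
g = dgl.add_self_loop(dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0]))))
eids = g.edge_ids(torch.tensor([0, 1]), torch.tensor([1, 2]))
neg_sampler = Uniform(1)  # one corrupted edge per positive edge
# Uniform keeps each source node and redraws the destination uniformly at random.
neg_src, neg_dst = neg_sampler(g, eids)
neg_edges = torch.stack((neg_src, neg_dst), dim=1)
print(neg_edges.shape)  # torch.Size([2, 2])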
...@@ -144,8 +152,16 @@ class SEALSampler(object): ...@@ -144,8 +152,16 @@ class SEALSampler(object):
subgraph = dgl.node_subgraph(self.graph, sample_nodes) subgraph = dgl.node_subgraph(self.graph, sample_nodes)
# Each node should have a unique node ID in the new subgraph # Each node should have a unique node ID in the new subgraph
u_id = int(torch.nonzero(subgraph.ndata[NID] == int(target_nodes[0]), as_tuple=False)) u_id = int(
v_id = int(torch.nonzero(subgraph.ndata[NID] == int(target_nodes[1]), as_tuple=False)) torch.nonzero(
subgraph.ndata[NID] == int(target_nodes[0]), as_tuple=False
)
)
v_id = int(
torch.nonzero(
subgraph.ndata[NID] == int(target_nodes[1]), as_tuple=False
)
)
# remove link between target nodes in positive subgraphs. # remove link between target nodes in positive subgraphs.
if subgraph.has_edges_between(u_id, v_id): if subgraph.has_edges_between(u_id, v_id):
...@@ -156,7 +172,7 @@ class SEALSampler(object): ...@@ -156,7 +172,7 @@ class SEALSampler(object):
subgraph.remove_edges(link_id) subgraph.remove_edges(link_id)
z = drnl_node_labeling(subgraph, u_id, v_id) z = drnl_node_labeling(subgraph, u_id, v_id)
subgraph.ndata['z'] = z subgraph.ndata["z"] = z
return subgraph return subgraph
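A small illustrative check, not from this commit, of the dgl.NID bookkeeping that sample_subgraph relies on: dgl.node_subgraph stores each kept node's original ID, which is how the two target endpoints are re-located inside the enclosing subgraph.

import dgl
import torch
from dgl import NID

g = dgl.graph((torch.tensor([0, 1, 2, 3]), torch.tensor([1, 2, 3, 0])))
sg = dgl.node_subgraph(g, [1, 2, 3])
print(sg.ndata[NID])  # tensor([1, 2, 3]): original IDs of the kept nodes
# Position of original node 2 inside the subgraph, as done for u_id/v_id above.
u_id = int(torch.nonzero(sg.ndata[NID] == 2, as_tuple=False))
print(u_id)  # 1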
...@@ -171,10 +187,19 @@ class SEALSampler(object): ...@@ -171,10 +187,19 @@ class SEALSampler(object):
def __call__(self, edges, labels): def __call__(self, edges, labels):
subgraph_list = [] subgraph_list = []
labels_list = [] labels_list = []
edge_dataset = EdgeDataSet(edges, labels, transform=self.sample_subgraph) edge_dataset = EdgeDataSet(
self.print_fn('Using {} workers in sampling job.'.format(self.num_workers)) edges, labels, transform=self.sample_subgraph
sampler = DataLoader(edge_dataset, batch_size=32, num_workers=self.num_workers, )
shuffle=False, collate_fn=self._collate) self.print_fn(
"Using {} workers in sampling job.".format(self.num_workers)
)
sampler = DataLoader(
edge_dataset,
batch_size=32,
num_workers=self.num_workers,
shuffle=False,
collate_fn=self._collate,
)
for subgraph, label in tqdm(sampler, ncols=100): for subgraph, label in tqdm(sampler, ncols=100):
label_copy = deepcopy(label) label_copy = deepcopy(label)
subgraph = dgl.unbatch(subgraph) subgraph = dgl.unbatch(subgraph)
...@@ -200,8 +225,20 @@ class SEALData(object): ...@@ -200,8 +225,20 @@ class SEALData(object):
use_coalesce(bool): True to coalesce the graph; graphs with multi-edges need to be coalesced use_coalesce(bool): True to coalesce the graph; graphs with multi-edges need to be coalesced
""" """
def __init__(self, g, split_edge, hop=1, neg_samples=1, subsample_ratio=1, prefix=None, save_dir=None, def __init__(
num_workers=32, shuffle=True, use_coalesce=True, print_fn=print): self,
g,
split_edge,
hop=1,
neg_samples=1,
subsample_ratio=1,
prefix=None,
save_dir=None,
num_workers=32,
shuffle=True,
use_coalesce=True,
print_fn=print,
):
self.g = g self.g = g
self.hop = hop self.hop = hop
self.subsample_ratio = subsample_ratio self.subsample_ratio = subsample_ratio
...@@ -209,15 +246,19 @@ class SEALData(object): ...@@ -209,15 +246,19 @@ class SEALData(object):
self.save_dir = save_dir self.save_dir = save_dir
self.print_fn = print_fn self.print_fn = print_fn
self.generator = PosNegEdgesGenerator(g=self.g, self.generator = PosNegEdgesGenerator(
split_edge=split_edge, g=self.g,
neg_samples=neg_samples, split_edge=split_edge,
subsample_ratio=subsample_ratio, neg_samples=neg_samples,
shuffle=shuffle) subsample_ratio=subsample_ratio,
shuffle=shuffle,
)
if use_coalesce: if use_coalesce:
for k, v in g.edata.items(): for k, v in g.edata.items():
g.edata[k] = v.float() # dgl.to_simple() requires float edge data g.edata[k] = v.float()  # dgl.to_simple() requires float edge data
self.g = dgl.to_simple(g, copy_ndata=True, copy_edata=True, aggregator='sum') self.g = dgl.to_simple(
g, copy_ndata=True, copy_edata=True, aggregator="sum"
)
self.ndata = {k: v for k, v in self.g.ndata.items()} self.ndata = {k: v for k, v in self.g.ndata.items()}
self.edata = {k: v for k, v in self.g.edata.items()} self.edata = {k: v for k, v in self.g.edata.items()}
...@@ -226,25 +267,28 @@ class SEALData(object): ...@@ -226,25 +267,28 @@ class SEALData(object):
self.print_fn("Save ndata and edata in class.") self.print_fn("Save ndata and edata in class.")
self.print_fn("Clear ndata and edata in graph.") self.print_fn("Clear ndata and edata in graph.")
self.sampler = SEALSampler(graph=self.g, self.sampler = SEALSampler(
hop=hop, graph=self.g, hop=hop, num_workers=num_workers, print_fn=print_fn
num_workers=num_workers, )
print_fn=print_fn)
def __call__(self, split_type): def __call__(self, split_type):
if split_type == 'train': if split_type == "train":
subsample_ratio = self.subsample_ratio subsample_ratio = self.subsample_ratio
else: else:
subsample_ratio = 1 subsample_ratio = 1
path = osp.join(self.save_dir or '', '{}_{}_{}-hop_{}-subsample.bin'.format(self.prefix, split_type, path = osp.join(
self.hop, subsample_ratio)) self.save_dir or "",
"{}_{}_{}-hop_{}-subsample.bin".format(
self.prefix, split_type, self.hop, subsample_ratio
),
)
if osp.exists(path): if osp.exists(path):
self.print_fn("Load existing processed {} files".format(split_type)) self.print_fn("Load existing processed {} files".format(split_type))
graph_list, data = dgl.load_graphs(path) graph_list, data = dgl.load_graphs(path)
dataset = GraphDataSet(graph_list, data['labels']) dataset = GraphDataSet(graph_list, data["labels"])
else: else:
self.print_fn("Processed {} files not exist.".format(split_type)) self.print_fn("Processed {} files not exist.".format(split_type))
...@@ -254,6 +298,6 @@ class SEALData(object): ...@@ -254,6 +298,6 @@ class SEALData(object):
graph_list, labels = self.sampler(edges, labels) graph_list, labels = self.sampler(edges, labels)
dataset = GraphDataSet(graph_list, labels) dataset = GraphDataSet(graph_list, labels)
dgl.save_graphs(path, graph_list, {'labels': labels}) dgl.save_graphs(path, graph_list, {"labels": labels})
self.print_fn("Save preprocessed subgraph to {}".format(path)) self.print_fn("Save preprocessed subgraph to {}".format(path))
return dataset return dataset
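A hedged usage sketch of SEALData; the dataset name, cache directory, and worker count below are illustrative choices, not part of this commit. The class pairs PosNegEdgesGenerator with SEALSampler and caches each split's labeled enclosing subgraphs as a .bin file.

from ogb.linkproppred import DglLinkPropPredDataset

dataset = DglLinkPropPredDataset(name="ogbl-collab")
graph = dataset[0]
split_edge = dataset.get_edge_split()
seal_data = SEALData(
    g=graph,
    split_edge=split_edge,
    hop=1,
    subsample_ratio=0.1,
    prefix="ogbl-collab",
    save_dir="./processed",
    num_workers=4,
    print_fn=print,
)
train_set = seal_data("train")  # GraphDataSet of labeled enclosing subgraphs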
import argparse import argparse
from scipy.sparse.csgraph import shortest_path
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import torch import torch
import dgl
from ogb.linkproppred import DglLinkPropPredDataset, Evaluator from ogb.linkproppred import DglLinkPropPredDataset, Evaluator
from scipy.sparse.csgraph import shortest_path
import dgl
def parse_arguments(): def parse_arguments():
""" """
Parse arguments Parse arguments
""" """
parser = argparse.ArgumentParser(description='SEAL') parser = argparse.ArgumentParser(description="SEAL")
parser.add_argument('--dataset', type=str, default='ogbl-collab') parser.add_argument("--dataset", type=str, default="ogbl-collab")
parser.add_argument('--gpu_id', type=int, default=0) parser.add_argument("--gpu_id", type=int, default=0)
parser.add_argument('--hop', type=int, default=1) parser.add_argument("--hop", type=int, default=1)
parser.add_argument('--model', type=str, default='dgcnn') parser.add_argument("--model", type=str, default="dgcnn")
parser.add_argument('--gcn_type', type=str, default='gcn') parser.add_argument("--gcn_type", type=str, default="gcn")
parser.add_argument('--num_layers', type=int, default=3) parser.add_argument("--num_layers", type=int, default=3)
parser.add_argument('--hidden_units', type=int, default=32) parser.add_argument("--hidden_units", type=int, default=32)
parser.add_argument('--sort_k', type=int, default=30) parser.add_argument("--sort_k", type=int, default=30)
parser.add_argument('--pooling', type=str, default='sum') parser.add_argument("--pooling", type=str, default="sum")
parser.add_argument('--dropout', type=float, default=0.5) parser.add_argument("--dropout", type=float, default=0.5)
parser.add_argument('--hits_k', type=int, default=50) parser.add_argument("--hits_k", type=int, default=50)
parser.add_argument('--lr', type=float, default=0.0001) parser.add_argument("--lr", type=float, default=0.0001)
parser.add_argument('--neg_samples', type=int, default=1) parser.add_argument("--neg_samples", type=int, default=1)
parser.add_argument('--subsample_ratio', type=float, default=0.1) parser.add_argument("--subsample_ratio", type=float, default=0.1)
parser.add_argument('--epochs', type=int, default=60) parser.add_argument("--epochs", type=int, default=60)
parser.add_argument('--batch_size', type=int, default=32) parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument('--eval_steps', type=int, default=5) parser.add_argument("--eval_steps", type=int, default=5)
parser.add_argument('--num_workers', type=int, default=32) parser.add_argument("--num_workers", type=int, default=32)
parser.add_argument('--random_seed', type=int, default=2021) parser.add_argument("--random_seed", type=int, default=2021)
parser.add_argument('--save_dir', type=str, default='./processed') parser.add_argument("--save_dir", type=str, default="./processed")
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -79,11 +81,15 @@ def drnl_node_labeling(subgraph, src, dst): ...@@ -79,11 +81,15 @@ def drnl_node_labeling(subgraph, src, dst):
idx = list(range(dst)) + list(range(dst + 1, adj.shape[0])) idx = list(range(dst)) + list(range(dst + 1, adj.shape[0]))
adj_wo_dst = adj[idx, :][:, idx] adj_wo_dst = adj[idx, :][:, idx]
dist2src = shortest_path(adj_wo_dst, directed=False, unweighted=True, indices=src) dist2src = shortest_path(
adj_wo_dst, directed=False, unweighted=True, indices=src
)
dist2src = np.insert(dist2src, dst, 0, axis=0) dist2src = np.insert(dist2src, dst, 0, axis=0)
dist2src = torch.from_numpy(dist2src) dist2src = torch.from_numpy(dist2src)
dist2dst = shortest_path(adj_wo_src, directed=False, unweighted=True, indices=dst - 1) dist2dst = shortest_path(
adj_wo_src, directed=False, unweighted=True, indices=dst - 1
)
dist2dst = np.insert(dist2dst, src, 0, axis=0) dist2dst = np.insert(dist2dst, src, 0, axis=0)
dist2dst = torch.from_numpy(dist2dst) dist2dst = torch.from_numpy(dist2dst)
...@@ -92,9 +98,9 @@ def drnl_node_labeling(subgraph, src, dst): ...@@ -92,9 +98,9 @@ def drnl_node_labeling(subgraph, src, dst):
z = 1 + torch.min(dist2src, dist2dst) z = 1 + torch.min(dist2src, dist2dst)
z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1) z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1)
z[src] = 1. z[src] = 1.0
z[dst] = 1. z[dst] = 1.0
z[torch.isnan(z)] = 0. z[torch.isnan(z)] = 0.0
return z.to(torch.long) return z.to(torch.long)
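For reference, the arithmetic above is Double-Radius Node Labeling from the SEAL paper: z(i) = 1 + min(d_src, d_dst) + (d // 2) * ((d // 2) + d % 2 - 1) with d = d_src + d_dst, the two target nodes forced to 1, and unreachable nodes set to 0. A toy sketch on raw distance tensors, with illustrative values only:

import torch

dist2src = torch.tensor([0.0, 1.0, 2.0])
dist2dst = torch.tensor([2.0, 1.0, 0.0])
d = dist2src + dist2dst
dist_over_2, dist_mod_2 = d // 2, d % 2  # computed in the collapsed lines above
z = 1 + torch.min(dist2src, dist2dst)
z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1)
print(z)  # tensor([1., 2., 1.]); the src/dst entries are then overwritten to 1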
...@@ -115,9 +121,11 @@ def evaluate_hits(name, pos_pred, neg_pred, K): ...@@ -115,9 +121,11 @@ def evaluate_hits(name, pos_pred, neg_pred, K):
""" """
evaluator = Evaluator(name) evaluator = Evaluator(name)
evaluator.K = K evaluator.K = K
hits = evaluator.eval({ hits = evaluator.eval(
'y_pred_pos': pos_pred, {
'y_pred_neg': neg_pred, "y_pred_pos": pos_pred,
})[f'hits@{K}'] "y_pred_neg": neg_pred,
}
)[f"hits@{K}"]
return hits return hits
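A hedged sketch of the Evaluator call above with dummy scores; the tensor sizes are illustrative. Hits@K is the fraction of positive edges whose score exceeds the K-th highest negative score.

import torch
from ogb.linkproppred import Evaluator

evaluator = Evaluator(name="ogbl-collab")
evaluator.K = 50
hits = evaluator.eval(
    {
        "y_pred_pos": torch.rand(1000),  # scores of positive edges
        "y_pred_neg": torch.rand(2000),  # scores of sampled negative edges
    }
)["hits@50"]
print(hits)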
...@@ -5,37 +5,46 @@ Paper: https://arxiv.org/abs/1902.07153 ...@@ -5,37 +5,46 @@ Paper: https://arxiv.org/abs/1902.07153
Code: https://github.com/Tiiiger/SGC Code: https://github.com/Tiiiger/SGC
SGC implementation in DGL. SGC implementation in DGL.
""" """
import argparse, time, math import argparse
import math
import time
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import dgl.function as fn
import dgl import dgl
from dgl.data import register_data_args import dgl.function as fn
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset from dgl.data import (
CiteseerGraphDataset,
CoraGraphDataset,
PubmedGraphDataset,
register_data_args,
)
from dgl.nn.pytorch.conv import SGConv from dgl.nn.pytorch.conv import SGConv
def evaluate(model, g, features, labels, mask): def evaluate(model, g, features, labels, mask):
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
logits = model(g, features)[mask] # only compute the evaluation set logits = model(g, features)[mask] # only compute the evaluation set
labels = labels[mask] labels = labels[mask]
_, indices = torch.max(logits, dim=1) _, indices = torch.max(logits, dim=1)
correct = torch.sum(indices == labels) correct = torch.sum(indices == labels)
return correct.item() * 1.0 / len(labels) return correct.item() * 1.0 / len(labels)
def main(args): def main(args):
# load and preprocess dataset # load and preprocess dataset
if args.dataset == 'cora': if args.dataset == "cora":
data = CoraGraphDataset() data = CoraGraphDataset()
elif args.dataset == 'citeseer': elif args.dataset == "citeseer":
data = CiteseerGraphDataset() data = CiteseerGraphDataset()
elif args.dataset == 'pubmed': elif args.dataset == "pubmed":
data = PubmedGraphDataset() data = PubmedGraphDataset()
else: else:
raise ValueError('Unknown dataset: {}'.format(args.dataset)) raise ValueError("Unknown dataset: {}".format(args.dataset))
g = data[0] g = data[0]
if args.gpu < 0: if args.gpu < 0:
...@@ -44,24 +53,29 @@ def main(args): ...@@ -44,24 +53,29 @@ def main(args):
cuda = True cuda = True
g = g.int().to(args.gpu) g = g.int().to(args.gpu)
features = g.ndata['feat'] features = g.ndata["feat"]
labels = g.ndata['label'] labels = g.ndata["label"]
train_mask = g.ndata['train_mask'] train_mask = g.ndata["train_mask"]
val_mask = g.ndata['val_mask'] val_mask = g.ndata["val_mask"]
test_mask = g.ndata['test_mask'] test_mask = g.ndata["test_mask"]
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_labels
n_edges = g.number_of_edges() n_edges = g.number_of_edges()
print("""----Data statistics------' print(
"""----Data statistics------'
#Edges %d #Edges %d
#Classes %d #Classes %d
#Train samples %d #Train samples %d
#Val samples %d #Val samples %d
#Test samples %d""" % #Test samples %d"""
(n_edges, n_classes, % (
train_mask.int().sum().item(), n_edges,
val_mask.int().sum().item(), n_classes,
test_mask.int().sum().item())) train_mask.int().sum().item(),
val_mask.int().sum().item(),
test_mask.int().sum().item(),
)
)
n_edges = g.number_of_edges() n_edges = g.number_of_edges()
# add self loop # add self loop
...@@ -69,20 +83,16 @@ def main(args): ...@@ -69,20 +83,16 @@ def main(args):
g = dgl.add_self_loop(g) g = dgl.add_self_loop(g)
# create SGC model # create SGC model
model = SGConv(in_feats, model = SGConv(in_feats, n_classes, k=2, cached=True, bias=args.bias)
n_classes,
k=2,
cached=True,
bias=args.bias)
if cuda: if cuda:
model.cuda() model.cuda()
loss_fcn = torch.nn.CrossEntropyLoss() loss_fcn = torch.nn.CrossEntropyLoss()
# use optimizer # use optimizer
optimizer = torch.optim.Adam(model.parameters(), optimizer = torch.optim.Adam(
lr=args.lr, model.parameters(), lr=args.lr, weight_decay=args.weight_decay
weight_decay=args.weight_decay) )
# initialize graph # initialize graph
dur = [] dur = []
...@@ -91,7 +101,7 @@ def main(args): ...@@ -91,7 +101,7 @@ def main(args):
if epoch >= 3: if epoch >= 3:
t0 = time.time() t0 = time.time()
# forward # forward
logits = model(g, features) # only compute the train set logits = model(g, features) # only compute the train set
loss = loss_fcn(logits[train_mask], labels[train_mask]) loss = loss_fcn(logits[train_mask], labels[train_mask])
optimizer.zero_grad() optimizer.zero_grad()
...@@ -102,28 +112,36 @@ def main(args): ...@@ -102,28 +112,36 @@ def main(args):
dur.append(time.time() - t0) dur.append(time.time() - t0)
acc = evaluate(model, g, features, labels, val_mask) acc = evaluate(model, g, features, labels, val_mask)
print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | " print(
"ETputs(KTEPS) {:.2f}". format(epoch, np.mean(dur), loss.item(), "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
acc, n_edges / np.mean(dur) / 1000)) "ETputs(KTEPS) {:.2f}".format(
epoch,
np.mean(dur),
loss.item(),
acc,
n_edges / np.mean(dur) / 1000,
)
)
print() print()
acc = evaluate(model, g, features, labels, test_mask) acc = evaluate(model, g, features, labels, test_mask)
print("Test Accuracy {:.4f}".format(acc)) print("Test Accuracy {:.4f}".format(acc))
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser(description='SGC') parser = argparse.ArgumentParser(description="SGC")
register_data_args(parser) register_data_args(parser)
parser.add_argument("--gpu", type=int, default=-1, parser.add_argument("--gpu", type=int, default=-1, help="gpu")
help="gpu") parser.add_argument("--lr", type=float, default=0.2, help="learning rate")
parser.add_argument("--lr", type=float, default=0.2, parser.add_argument(
help="learning rate") "--bias", action="store_true", default=False, help="flag to use bias"
parser.add_argument("--bias", action='store_true', default=False, )
help="flag to use bias") parser.add_argument(
parser.add_argument("--n-epochs", type=int, default=100, "--n-epochs", type=int, default=100, help="number of training epochs"
help="number of training epochs") )
parser.add_argument("--weight-decay", type=float, default=5e-6, parser.add_argument(
help="Weight for L2 loss") "--weight-decay", type=float, default=5e-6, help="Weight for L2 loss"
)
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
......
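A minimal SGConv sketch on a toy graph, with illustrative shapes and values not taken from the commit: k=2 applies two rounds of normalized neighborhood averaging, and cached=True stores the propagated features so later forward passes reduce to the final linear layer.

import dgl
import torch
from dgl.nn.pytorch.conv import SGConv

g = dgl.add_self_loop(dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0]))))
feat = torch.randn(3, 5)  # 3 nodes, 5 input features
model = SGConv(5, 2, k=2, cached=True, bias=True)
logits = model(g, feat)
print(logits.shape)  # torch.Size([3, 2])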