Unverified Commit be8763fa authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4679)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent eae6ce2a
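The whitespace-only rewrites in the diff below are what the Black formatter produces; the exact invocation is not recorded in the commit, but a minimal sketch with Black's Python API looks like this (the 79-character line length is an assumption inferred from where the calls wrap):

# Minimal sketch, not part of the commit: reformat one long line the way Black does.
# The line_length value is an assumption; Black's default is 88.
import black

src = (
    "optimizer = th.optim.Adam(model.parameters(), "
    "lr=args.lr, weight_decay=args.l2norm)\n"
)
print(black.format_str(src, mode=black.Mode(line_length=79)))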
......@@ -3,24 +3,26 @@ Paper: https://arxiv.org/abs/1703.06103
Reference Code: https://github.com/tkipf/relational-gcn
"""
import argparse
import numpy as np
import time
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
from model import EntityClassify_HeteroAPI
from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset
def main(args):
# load graph data
if args.dataset == 'aifb':
if args.dataset == "aifb":
dataset = AIFBDataset()
elif args.dataset == 'mutag':
elif args.dataset == "mutag":
dataset = MUTAGDataset()
elif args.dataset == 'bgs':
elif args.dataset == "bgs":
dataset = BGSDataset()
elif args.dataset == 'am':
elif args.dataset == "am":
dataset = AMDataset()
else:
raise ValueError()
......@@ -28,11 +30,11 @@ def main(args):
g = dataset[0]
category = dataset.predict_category
num_classes = dataset.num_classes
train_mask = g.nodes[category].data.pop('train_mask')
test_mask = g.nodes[category].data.pop('test_mask')
train_mask = g.nodes[category].data.pop("train_mask")
test_mask = g.nodes[category].data.pop("test_mask")
train_idx = th.nonzero(train_mask, as_tuple=False).squeeze()
test_idx = th.nonzero(test_mask, as_tuple=False).squeeze()
labels = g.nodes[category].data.pop('labels')
labels = g.nodes[category].data.pop("labels")
category_id = len(g.ntypes)
for i, ntype in enumerate(g.ntypes):
if ntype == category:
......@@ -40,8 +42,8 @@ def main(args):
# split dataset into train, validate, test
if args.validation:
val_idx = train_idx[:len(train_idx) // 5]
train_idx = train_idx[len(train_idx) // 5:]
val_idx = train_idx[: len(train_idx) // 5]
train_idx = train_idx[len(train_idx) // 5 :]
else:
val_idx = train_idx
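For reference, the one-fifth validation split above operates on index tensors; a toy sketch, not part of the diff:

# Toy illustration of the 20/80 split used above.
import torch as th

train_idx = th.arange(10)                     # pretend there are 10 labeled nodes
val_idx = train_idx[: len(train_idx) // 5]    # first 20%  -> tensor([0, 1])
train_idx = train_idx[len(train_idx) // 5 :]  # last 80%   -> tensor([2, 3, ..., 9])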
......@@ -49,25 +51,29 @@ def main(args):
use_cuda = args.gpu >= 0 and th.cuda.is_available()
if use_cuda:
th.cuda.set_device(args.gpu)
g = g.to('cuda:%d' % args.gpu)
g = g.to("cuda:%d" % args.gpu)
labels = labels.cuda()
train_idx = train_idx.cuda()
test_idx = test_idx.cuda()
# create model
model = EntityClassify_HeteroAPI(g,
args.n_hidden,
num_classes,
num_bases=args.n_bases,
num_hidden_layers=args.n_layers - 2,
dropout=args.dropout,
use_self_loop=args.use_self_loop)
model = EntityClassify_HeteroAPI(
g,
args.n_hidden,
num_classes,
num_bases=args.n_bases,
num_hidden_layers=args.n_layers - 2,
dropout=args.dropout,
use_self_loop=args.use_self_loop,
)
if use_cuda:
model.cuda()
# optimizer
optimizer = th.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2norm)
optimizer = th.optim.Adam(
model.parameters(), lr=args.lr, weight_decay=args.l2norm
)
# training loop
print("start training...")
......@@ -83,11 +89,23 @@ def main(args):
t1 = time.time()
dur.append(t1 - t0)
train_acc = th.sum(logits[train_idx].argmax(dim=1) == labels[train_idx]).item() / len(train_idx)
train_acc = th.sum(
logits[train_idx].argmax(dim=1) == labels[train_idx]
).item() / len(train_idx)
val_loss = F.cross_entropy(logits[val_idx], labels[val_idx])
val_acc = th.sum(logits[val_idx].argmax(dim=1) == labels[val_idx]).item() / len(val_idx)
print("Epoch {:05d} | Train Acc: {:.4f} | Train Loss: {:.4f} | Valid Acc: {:.4f} | Valid loss: {:.4f} | Time: {:.4f}".
format(epoch, train_acc, loss.item(), val_acc, val_loss.item(), np.average(dur)))
val_acc = th.sum(
logits[val_idx].argmax(dim=1) == labels[val_idx]
).item() / len(val_idx)
print(
"Epoch {:05d} | Train Acc: {:.4f} | Train Loss: {:.4f} | Valid Acc: {:.4f} | Valid loss: {:.4f} | Time: {:.4f}".format(
epoch,
train_acc,
loss.item(),
val_acc,
val_loss.item(),
np.average(dur),
)
)
print()
if args.model_path is not None:
th.save(model.state_dict(), args.model_path)
......@@ -95,37 +113,59 @@ def main(args):
model.eval()
logits = model.forward()[category]
test_loss = F.cross_entropy(logits[test_idx], labels[test_idx])
test_acc = th.sum(logits[test_idx].argmax(dim=1) == labels[test_idx]).item() / len(test_idx)
print("Test Acc: {:.4f} | Test loss: {:.4f}".format(test_acc, test_loss.item()))
test_acc = th.sum(
logits[test_idx].argmax(dim=1) == labels[test_idx]
).item() / len(test_idx)
print(
"Test Acc: {:.4f} | Test loss: {:.4f}".format(
test_acc, test_loss.item()
)
)
print()
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='RGCN')
parser.add_argument("--dropout", type=float, default=0,
help="dropout probability")
parser.add_argument("--n-hidden", type=int, default=16,
help="number of hidden units")
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
parser.add_argument("--lr", type=float, default=1e-2,
help="learning rate")
parser.add_argument("--n-bases", type=int, default=-1,
help="number of filter weight matrices, default: -1 [use all]")
parser.add_argument("--n-layers", type=int, default=2,
help="number of propagation rounds")
parser.add_argument("-e", "--n-epochs", type=int, default=50,
help="number of training epochs")
parser.add_argument("-d", "--dataset", type=str, required=True,
help="dataset to use")
parser.add_argument("--model_path", type=str, default=None,
help='path for save the model')
parser.add_argument("--l2norm", type=float, default=0,
help="l2 norm coef")
parser.add_argument("--use-self-loop", default=False, action='store_true',
help="include self feature as a special relation")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="RGCN")
parser.add_argument(
"--dropout", type=float, default=0, help="dropout probability"
)
parser.add_argument(
"--n-hidden", type=int, default=16, help="number of hidden units"
)
parser.add_argument("--gpu", type=int, default=-1, help="gpu")
parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
parser.add_argument(
"--n-bases",
type=int,
default=-1,
help="number of filter weight matrices, default: -1 [use all]",
)
parser.add_argument(
"--n-layers", type=int, default=2, help="number of propagation rounds"
)
parser.add_argument(
"-e",
"--n-epochs",
type=int,
default=50,
help="number of training epochs",
)
parser.add_argument(
"-d", "--dataset", type=str, required=True, help="dataset to use"
)
parser.add_argument(
"--model_path", type=str, default=None, help="path for save the model"
)
parser.add_argument("--l2norm", type=float, default=0, help="l2 norm coef")
parser.add_argument(
"--use-self-loop",
default=False,
action="store_true",
help="include self feature as a special relation",
)
fp = parser.add_mutually_exclusive_group(required=False)
fp.add_argument('--validation', dest='validation', action='store_true')
fp.add_argument('--testing', dest='validation', action='store_false')
fp.add_argument("--validation", dest="validation", action="store_true")
fp.add_argument("--testing", dest="validation", action="store_false")
parser.set_defaults(validation=True)
args = parser.parse_args()
......
......@@ -4,14 +4,16 @@ Reference Code: https://github.com/tkipf/relational-gcn
"""
import argparse
import itertools
import numpy as np
import time
import numpy as np
import torch as th
import torch.nn.functional as F
from model import EntityClassify, RelGraphEmbed
import dgl
from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
from model import EntityClassify, RelGraphEmbed
from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset
def extract_embed(node_embed, input_nodes):
emb = {}
......@@ -20,6 +22,7 @@ def extract_embed(node_embed, input_nodes):
emb[ntype] = node_embed[ntype][nid]
return emb
def evaluate(model, loader, node_embed, labels, category, device):
model.eval()
total_loss = 0
......@@ -39,22 +42,23 @@ def evaluate(model, loader, node_embed, labels, category, device):
count += len(seeds)
return total_loss / count, total_acc / count
def main(args):
# check cuda
device = 'cpu'
device = "cpu"
use_cuda = args.gpu >= 0 and th.cuda.is_available()
if use_cuda:
th.cuda.set_device(args.gpu)
device = 'cuda:%d' % args.gpu
device = "cuda:%d" % args.gpu
# load graph data
if args.dataset == 'aifb':
if args.dataset == "aifb":
dataset = AIFBDataset()
elif args.dataset == 'mutag':
elif args.dataset == "mutag":
dataset = MUTAGDataset()
elif args.dataset == 'bgs':
elif args.dataset == "bgs":
dataset = BGSDataset()
elif args.dataset == 'am':
elif args.dataset == "am":
dataset = AMDataset()
else:
raise ValueError()
......@@ -62,16 +66,16 @@ def main(args):
g = dataset[0]
category = dataset.predict_category
num_classes = dataset.num_classes
train_mask = g.nodes[category].data.pop('train_mask')
test_mask = g.nodes[category].data.pop('test_mask')
train_mask = g.nodes[category].data.pop("train_mask")
test_mask = g.nodes[category].data.pop("test_mask")
train_idx = th.nonzero(train_mask, as_tuple=False).squeeze()
test_idx = th.nonzero(test_mask, as_tuple=False).squeeze()
labels = g.nodes[category].data.pop('labels')
labels = g.nodes[category].data.pop("labels")
# split dataset into train, validate, test
if args.validation:
val_idx = train_idx[:len(train_idx) // 5]
train_idx = train_idx[len(train_idx) // 5:]
val_idx = train_idx[: len(train_idx) // 5]
train_idx = train_idx[len(train_idx) // 5 :]
else:
val_idx = train_idx
......@@ -84,29 +88,45 @@ def main(args):
node_embed = embed_layer()
# create model
model = EntityClassify(g,
args.n_hidden,
num_classes,
num_bases=args.n_bases,
num_hidden_layers=args.n_layers - 2,
dropout=args.dropout,
use_self_loop=args.use_self_loop)
model = EntityClassify(
g,
args.n_hidden,
num_classes,
num_bases=args.n_bases,
num_hidden_layers=args.n_layers - 2,
dropout=args.dropout,
use_self_loop=args.use_self_loop,
)
if use_cuda:
model.cuda()
# train sampler
sampler = dgl.dataloading.MultiLayerNeighborSampler([args.fanout] * args.n_layers)
sampler = dgl.dataloading.MultiLayerNeighborSampler(
[args.fanout] * args.n_layers
)
loader = dgl.dataloading.DataLoader(
g, {category: train_idx}, sampler,
batch_size=args.batch_size, shuffle=True, num_workers=0)
g,
{category: train_idx},
sampler,
batch_size=args.batch_size,
shuffle=True,
num_workers=0,
)
# validation sampler
# we do not use full neighbor to save computation resources
val_sampler = dgl.dataloading.MultiLayerNeighborSampler([args.fanout] * args.n_layers)
val_sampler = dgl.dataloading.MultiLayerNeighborSampler(
[args.fanout] * args.n_layers
)
val_loader = dgl.dataloading.DataLoader(
g, {category: val_idx}, val_sampler,
batch_size=args.batch_size, shuffle=True, num_workers=0)
g,
{category: val_idx},
val_sampler,
batch_size=args.batch_size,
shuffle=True,
num_workers=0,
)
# optimizer
all_params = itertools.chain(model.parameters(), embed_layer.parameters())
......@@ -123,12 +143,14 @@ def main(args):
for i, (input_nodes, seeds, blocks) in enumerate(loader):
blocks = [blk.to(device) for blk in blocks]
seeds = seeds[category] # we only predict the nodes with type "category"
seeds = seeds[
category
] # we only predict the nodes with type "category"
batch_tic = time.time()
emb = extract_embed(node_embed, input_nodes)
lbl = labels[seeds]
if use_cuda:
emb = {k : e.cuda() for k, e in emb.items()}
emb = {k: e.cuda() for k, e in emb.items()}
lbl = lbl.cuda()
logits = model(emb, blocks)[category]
loss = F.cross_entropy(logits, lbl)
......@@ -136,63 +158,96 @@ def main(args):
optimizer.step()
train_acc = th.sum(logits.argmax(dim=1) == lbl).item() / len(seeds)
print("Epoch {:05d} | Batch {:03d} | Train Acc: {:.4f} | Train Loss: {:.4f} | Time: {:.4f}".
format(epoch, i, train_acc, loss.item(), time.time() - batch_tic))
print(
"Epoch {:05d} | Batch {:03d} | Train Acc: {:.4f} | Train Loss: {:.4f} | Time: {:.4f}".format(
epoch, i, train_acc, loss.item(), time.time() - batch_tic
)
)
if epoch > 3:
dur.append(time.time() - t0)
val_loss, val_acc = evaluate(model, val_loader, node_embed, labels, category, device)
print("Epoch {:05d} | Valid Acc: {:.4f} | Valid loss: {:.4f} | Time: {:.4f}".
format(epoch, val_acc, val_loss, np.average(dur)))
val_loss, val_acc = evaluate(
model, val_loader, node_embed, labels, category, device
)
print(
"Epoch {:05d} | Valid Acc: {:.4f} | Valid loss: {:.4f} | Time: {:.4f}".format(
epoch, val_acc, val_loss, np.average(dur)
)
)
print()
if args.model_path is not None:
th.save(model.state_dict(), args.model_path)
output = model.inference(
g, args.batch_size, 'cuda' if use_cuda else 'cpu', 0, node_embed)
g, args.batch_size, "cuda" if use_cuda else "cpu", 0, node_embed
)
test_pred = output[category][test_idx]
test_labels = labels[test_idx].to(test_pred.device)
test_acc = (test_pred.argmax(1) == test_labels).float().mean()
print("Test Acc: {:.4f}".format(test_acc))
print()
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='RGCN')
parser.add_argument("--dropout", type=float, default=0,
help="dropout probability")
parser.add_argument("--n-hidden", type=int, default=16,
help="number of hidden units")
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
parser.add_argument("--lr", type=float, default=1e-2,
help="learning rate")
parser.add_argument("--n-bases", type=int, default=-1,
help="number of filter weight matrices, default: -1 [use all]")
parser.add_argument("--n-layers", type=int, default=2,
help="number of propagation rounds")
parser.add_argument("-e", "--n-epochs", type=int, default=20,
help="number of training epochs")
parser.add_argument("-d", "--dataset", type=str, required=True,
help="dataset to use")
parser.add_argument("--model_path", type=str, default=None,
help='path for save the model')
parser.add_argument("--l2norm", type=float, default=0,
help="l2 norm coef")
parser.add_argument("--use-self-loop", default=False, action='store_true',
help="include self feature as a special relation")
parser.add_argument("--batch-size", type=int, default=100,
help="Mini-batch size. If -1, use full graph training.")
parser.add_argument("--fanout", type=int, default=4,
help="Fan-out of neighbor sampling.")
parser.add_argument('--data-cpu', action='store_true',
help="By default the script puts all node features and labels "
"on GPU when using it to save time for data copy. This may "
"be undesired if they cannot fit in GPU memory at once. "
"This flag disables that.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="RGCN")
parser.add_argument(
"--dropout", type=float, default=0, help="dropout probability"
)
parser.add_argument(
"--n-hidden", type=int, default=16, help="number of hidden units"
)
parser.add_argument("--gpu", type=int, default=-1, help="gpu")
parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
parser.add_argument(
"--n-bases",
type=int,
default=-1,
help="number of filter weight matrices, default: -1 [use all]",
)
parser.add_argument(
"--n-layers", type=int, default=2, help="number of propagation rounds"
)
parser.add_argument(
"-e",
"--n-epochs",
type=int,
default=20,
help="number of training epochs",
)
parser.add_argument(
"-d", "--dataset", type=str, required=True, help="dataset to use"
)
parser.add_argument(
"--model_path", type=str, default=None, help="path for save the model"
)
parser.add_argument("--l2norm", type=float, default=0, help="l2 norm coef")
parser.add_argument(
"--use-self-loop",
default=False,
action="store_true",
help="include self feature as a special relation",
)
parser.add_argument(
"--batch-size",
type=int,
default=100,
help="Mini-batch size. If -1, use full graph training.",
)
parser.add_argument(
"--fanout", type=int, default=4, help="Fan-out of neighbor sampling."
)
parser.add_argument(
"--data-cpu",
action="store_true",
help="By default the script puts all node features and labels "
"on GPU when using it to save time for data copy. This may "
"be undesired if they cannot fit in GPU memory at once. "
"This flag disables that.",
)
fp = parser.add_mutually_exclusive_group(required=False)
fp.add_argument('--validation', dest='validation', action='store_true')
fp.add_argument('--testing', dest='validation', action='store_false')
fp.add_argument("--validation", dest="validation", action="store_true")
fp.add_argument("--testing", dest="validation", action="store_false")
parser.set_defaults(validation=True)
args = parser.parse_args()
......
......@@ -4,10 +4,12 @@ from collections import defaultdict
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import tqdm
import dgl
import dgl.function as fn
import dgl.nn as dglnn
import tqdm
class RelGraphConvLayer(nn.Module):
r"""Relational graph convolution layer.
......@@ -33,17 +35,20 @@ class RelGraphConvLayer(nn.Module):
dropout : float, optional
Dropout rate. Default: 0.0
"""
def __init__(self,
in_feat,
out_feat,
rel_names,
num_bases,
*,
weight=True,
bias=True,
activation=None,
self_loop=False,
dropout=0.0):
def __init__(
self,
in_feat,
out_feat,
rel_names,
num_bases,
*,
weight=True,
bias=True,
activation=None,
self_loop=False,
dropout=0.0
):
super(RelGraphConvLayer, self).__init__()
self.in_feat = in_feat
self.out_feat = out_feat
......@@ -53,19 +58,29 @@ class RelGraphConvLayer(nn.Module):
self.activation = activation
self.self_loop = self_loop
self.conv = dglnn.HeteroGraphConv({
rel : dglnn.GraphConv(in_feat, out_feat, norm='right', weight=False, bias=False)
self.conv = dglnn.HeteroGraphConv(
{
rel: dglnn.GraphConv(
in_feat, out_feat, norm="right", weight=False, bias=False
)
for rel in rel_names
})
}
)
self.use_weight = weight
self.use_basis = num_bases < len(self.rel_names) and weight
if self.use_weight:
if self.use_basis:
self.basis = dglnn.WeightBasis((in_feat, out_feat), num_bases, len(self.rel_names))
self.basis = dglnn.WeightBasis(
(in_feat, out_feat), num_bases, len(self.rel_names)
)
else:
self.weight = nn.Parameter(th.Tensor(len(self.rel_names), in_feat, out_feat))
nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
self.weight = nn.Parameter(
th.Tensor(len(self.rel_names), in_feat, out_feat)
)
nn.init.xavier_uniform_(
self.weight, gain=nn.init.calculate_gain("relu")
)
# bias
if bias:
......@@ -75,8 +90,9 @@ class RelGraphConvLayer(nn.Module):
# weight for self loop
if self.self_loop:
self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
nn.init.xavier_uniform_(self.loop_weight,
gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(
self.loop_weight, gain=nn.init.calculate_gain("relu")
)
self.dropout = nn.Dropout(dropout)
......@@ -98,14 +114,18 @@ class RelGraphConvLayer(nn.Module):
g = g.local_var()
if self.use_weight:
weight = self.basis() if self.use_basis else self.weight
wdict = {self.rel_names[i] : {'weight' : w.squeeze(0)}
for i, w in enumerate(th.split(weight, 1, dim=0))}
wdict = {
self.rel_names[i]: {"weight": w.squeeze(0)}
for i, w in enumerate(th.split(weight, 1, dim=0))
}
else:
wdict = {}
if g.is_block:
inputs_src = inputs
inputs_dst = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}
inputs_dst = {
k: v[: g.number_of_dst_nodes(k)] for k, v in inputs.items()
}
else:
inputs_src = inputs_dst = inputs
......@@ -119,7 +139,9 @@ class RelGraphConvLayer(nn.Module):
if self.activation:
h = self.activation(h)
return self.dropout(h)
return {ntype : _apply(ntype, h) for ntype, h in hs.items()}
return {ntype: _apply(ntype, h) for ntype, h in hs.items()}
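The wdict comprehension above splits the stacked (num_rels, in_feat, out_feat) weight into one matrix per relation name; a toy sketch with made-up relation names and shapes:

# Toy sketch; relation names and shapes are illustrative only, not from the diff.
import torch as th

rel_names = ["writes", "cites"]
weight = th.randn(len(rel_names), 4, 8)     # (num_rels, in_feat, out_feat)
wdict = {
    rel_names[i]: {"weight": w.squeeze(0)}  # one (in_feat, out_feat) matrix per relation
    for i, w in enumerate(th.split(weight, 1, dim=0))
}
print(wdict["cites"]["weight"].shape)       # torch.Size([4, 8])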
class RelGraphConvLayerHeteroAPI(nn.Module):
r"""Relational graph convolution layer.
......@@ -145,17 +167,20 @@ class RelGraphConvLayerHeteroAPI(nn.Module):
dropout : float, optional
Dropout rate. Default: 0.0
"""
def __init__(self,
in_feat,
out_feat,
rel_names,
num_bases,
*,
weight=True,
bias=True,
activation=None,
self_loop=False,
dropout=0.0):
def __init__(
self,
in_feat,
out_feat,
rel_names,
num_bases,
*,
weight=True,
bias=True,
activation=None,
self_loop=False,
dropout=0.0
):
super(RelGraphConvLayerHeteroAPI, self).__init__()
self.in_feat = in_feat
self.out_feat = out_feat
......@@ -169,10 +194,16 @@ class RelGraphConvLayerHeteroAPI(nn.Module):
self.use_basis = num_bases < len(self.rel_names) and weight
if self.use_weight:
if self.use_basis:
self.basis = dglnn.WeightBasis((in_feat, out_feat), num_bases, len(self.rel_names))
self.basis = dglnn.WeightBasis(
(in_feat, out_feat), num_bases, len(self.rel_names)
)
else:
self.weight = nn.Parameter(th.Tensor(len(self.rel_names), in_feat, out_feat))
nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
self.weight = nn.Parameter(
th.Tensor(len(self.rel_names), in_feat, out_feat)
)
nn.init.xavier_uniform_(
self.weight, gain=nn.init.calculate_gain("relu")
)
# bias
if bias:
......@@ -182,8 +213,9 @@ class RelGraphConvLayerHeteroAPI(nn.Module):
# weight for self loop
if self.self_loop:
self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
nn.init.xavier_uniform_(self.loop_weight,
gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(
self.loop_weight, gain=nn.init.calculate_gain("relu")
)
self.dropout = nn.Dropout(dropout)
......@@ -205,29 +237,33 @@ class RelGraphConvLayerHeteroAPI(nn.Module):
g = g.local_var()
if self.use_weight:
weight = self.basis() if self.use_basis else self.weight
wdict = {self.rel_names[i] : {'weight' : w.squeeze(0)}
for i, w in enumerate(th.split(weight, 1, dim=0))}
wdict = {
self.rel_names[i]: {"weight": w.squeeze(0)}
for i, w in enumerate(th.split(weight, 1, dim=0))
}
else:
wdict = {}
inputs_src = inputs_dst = inputs
for srctype,_,_ in g.canonical_etypes:
g.nodes[srctype].data['h'] = inputs[srctype]
for srctype, _, _ in g.canonical_etypes:
g.nodes[srctype].data["h"] = inputs[srctype]
if self.use_weight:
g.apply_edges(fn.copy_u('h', 'm'))
m = g.edata['m']
g.apply_edges(fn.copy_u("h", "m"))
m = g.edata["m"]
for rel in g.canonical_etypes:
_, etype, _ = rel
g.edges[rel].data['h*w_r'] = th.matmul(m[rel], wdict[etype]['weight'])
g.edges[rel].data["h*w_r"] = th.matmul(
m[rel], wdict[etype]["weight"]
)
else:
g.apply_edges(fn.copy_u('h', 'h*w_r'))
g.apply_edges(fn.copy_u("h", "h*w_r"))
g.update_all(fn.copy_e('h*w_r', 'm'), fn.sum('m', 'h'))
g.update_all(fn.copy_e("h*w_r", "m"), fn.sum("m", "h"))
def _apply(ntype):
h = g.nodes[ntype].data['h']
h = g.nodes[ntype].data["h"]
if self.self_loop:
h = h + th.matmul(inputs_dst[ntype], self.loop_weight)
if self.bias:
......@@ -235,16 +271,16 @@ class RelGraphConvLayerHeteroAPI(nn.Module):
if self.activation:
h = self.activation(h)
return self.dropout(h)
return {ntype : _apply(ntype) for ntype in g.dsttypes}
return {ntype: _apply(ntype) for ntype in g.dsttypes}
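The copy_u / copy_e / sum combination in RelGraphConvLayerHeteroAPI.forward is ordinary DGL message passing; a self-contained toy example on a homogeneous graph, not part of the commit:

# copy_u copies each source node's "h" onto its out-edges as message "m";
# fn.sum then aggregates incoming messages into a new node feature.
import dgl
import dgl.function as fn
import torch as th

g = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 0])))  # directed 3-cycle
g.ndata["h"] = th.tensor([[1.0], [2.0], [3.0]])
g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h_sum"))
print(g.ndata["h_sum"])  # each node ends up with its single in-neighbour's feature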
class RelGraphEmbed(nn.Module):
r"""Embedding layer for featureless heterograph."""
def __init__(self,
g,
embed_size,
embed_name='embed',
activation=None,
dropout=0.0):
def __init__(
self, g, embed_size, embed_name="embed", activation=None, dropout=0.0
):
super(RelGraphEmbed, self).__init__()
self.g = g
self.embed_size = embed_size
......@@ -255,8 +291,10 @@ class RelGraphEmbed(nn.Module):
# create weight embeddings for each node for each relation
self.embeds = nn.ParameterDict()
for ntype in g.ntypes:
embed = nn.Parameter(th.Tensor(g.number_of_nodes(ntype), self.embed_size))
nn.init.xavier_uniform_(embed, gain=nn.init.calculate_gain('relu'))
embed = nn.Parameter(
th.Tensor(g.number_of_nodes(ntype), self.embed_size)
)
nn.init.xavier_uniform_(embed, gain=nn.init.calculate_gain("relu"))
self.embeds[ntype] = embed
def forward(self, block=None):
......@@ -276,14 +314,18 @@ class RelGraphEmbed(nn.Module):
"""
return self.embeds
class EntityClassify(nn.Module):
def __init__(self,
g,
h_dim, out_dim,
num_bases,
num_hidden_layers=1,
dropout=0,
use_self_loop=False):
def __init__(
self,
g,
h_dim,
out_dim,
num_bases,
num_hidden_layers=1,
dropout=0,
use_self_loop=False,
):
super(EntityClassify, self).__init__()
self.g = g
self.h_dim = h_dim
......@@ -301,21 +343,42 @@ class EntityClassify(nn.Module):
self.embed_layer = RelGraphEmbed(g, self.h_dim)
self.layers = nn.ModuleList()
# i2h
self.layers.append(RelGraphConvLayer(
self.h_dim, self.h_dim, self.rel_names,
self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
dropout=self.dropout, weight=False))
self.layers.append(
RelGraphConvLayer(
self.h_dim,
self.h_dim,
self.rel_names,
self.num_bases,
activation=F.relu,
self_loop=self.use_self_loop,
dropout=self.dropout,
weight=False,
)
)
# h2h
for i in range(self.num_hidden_layers):
self.layers.append(RelGraphConvLayer(
self.h_dim, self.h_dim, self.rel_names,
self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
dropout=self.dropout))
self.layers.append(
RelGraphConvLayer(
self.h_dim,
self.h_dim,
self.rel_names,
self.num_bases,
activation=F.relu,
self_loop=self.use_self_loop,
dropout=self.dropout,
)
)
# h2o
self.layers.append(RelGraphConvLayer(
self.h_dim, self.out_dim, self.rel_names,
self.num_bases, activation=None,
self_loop=self.use_self_loop))
self.layers.append(
RelGraphConvLayer(
self.h_dim,
self.out_dim,
self.rel_names,
self.num_bases,
activation=None,
self_loop=self.use_self_loop,
)
)
def forward(self, h=None, blocks=None):
if h is None:
......@@ -346,8 +409,10 @@ class EntityClassify(nn.Module):
y = {
k: th.zeros(
g.number_of_nodes(k),
self.h_dim if l != len(self.layers) - 1 else self.out_dim)
for k in g.ntypes}
self.h_dim if l != len(self.layers) - 1 else self.out_dim,
)
for k in g.ntypes
}
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.DataLoader(
......@@ -357,12 +422,16 @@ class EntityClassify(nn.Module):
batch_size=batch_size,
shuffle=True,
drop_last=False,
num_workers=num_workers)
num_workers=num_workers,
)
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
block = blocks[0].to(device)
h = {k: x[k][input_nodes[k]].to(device) for k in input_nodes.keys()}
h = {
k: x[k][input_nodes[k]].to(device)
for k in input_nodes.keys()
}
h = layer(block, h)
for k in output_nodes.keys():
......@@ -371,14 +440,18 @@ class EntityClassify(nn.Module):
x = y
return y
class EntityClassify_HeteroAPI(nn.Module):
def __init__(self,
g,
h_dim, out_dim,
num_bases,
num_hidden_layers=1,
dropout=0,
use_self_loop=False):
def __init__(
self,
g,
h_dim,
out_dim,
num_bases,
num_hidden_layers=1,
dropout=0,
use_self_loop=False,
):
super(EntityClassify_HeteroAPI, self).__init__()
self.g = g
self.h_dim = h_dim
......@@ -396,21 +469,42 @@ class EntityClassify_HeteroAPI(nn.Module):
self.embed_layer = RelGraphEmbed(g, self.h_dim)
self.layers = nn.ModuleList()
# i2h
self.layers.append(RelGraphConvLayerHeteroAPI(
self.h_dim, self.h_dim, self.rel_names,
self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
dropout=self.dropout, weight=False))
self.layers.append(
RelGraphConvLayerHeteroAPI(
self.h_dim,
self.h_dim,
self.rel_names,
self.num_bases,
activation=F.relu,
self_loop=self.use_self_loop,
dropout=self.dropout,
weight=False,
)
)
# h2h
for i in range(self.num_hidden_layers):
self.layers.append(RelGraphConvLayerHeteroAPI(
self.h_dim, self.h_dim, self.rel_names,
self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
dropout=self.dropout))
self.layers.append(
RelGraphConvLayerHeteroAPI(
self.h_dim,
self.h_dim,
self.rel_names,
self.num_bases,
activation=F.relu,
self_loop=self.use_self_loop,
dropout=self.dropout,
)
)
# h2o
self.layers.append(RelGraphConvLayerHeteroAPI(
self.h_dim, self.out_dim, self.rel_names,
self.num_bases, activation=None,
self_loop=self.use_self_loop))
self.layers.append(
RelGraphConvLayerHeteroAPI(
self.h_dim,
self.out_dim,
self.rel_names,
self.num_bases,
activation=None,
self_loop=self.use_self_loop,
)
)
def forward(self, h=None, blocks=None):
if h is None:
......@@ -441,8 +535,10 @@ class EntityClassify_HeteroAPI(nn.Module):
y = {
k: th.zeros(
g.number_of_nodes(k),
self.h_dim if l != len(self.layers) - 1 else self.out_dim)
for k in g.ntypes}
self.h_dim if l != len(self.layers) - 1 else self.out_dim,
)
for k in g.ntypes
}
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.DataLoader(
......@@ -452,12 +548,16 @@ class EntityClassify_HeteroAPI(nn.Module):
batch_size=batch_size,
shuffle=True,
drop_last=False,
num_workers=num_workers)
num_workers=num_workers,
)
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
block = blocks[0].to(device)
h = {k: x[k][input_nodes[k]].to(device) for k in input_nodes.keys()}
h = {
k: x[k][input_nodes[k]].to(device)
for k in input_nodes.keys()
}
h = layer(block, h)
for k in h.keys():
......
"""Infering Relational Data with Graph Convolutional Networks
"""
import argparse
import torch as th
from functools import partial
import torch.nn.functional as F
from dgl.data.rdf import AIFB, MUTAG, BGS, AM
import torch as th
import torch.nn.functional as F
from entity_classify import EntityClassify
from dgl.data.rdf import AIFB, AM, BGS, MUTAG
def main(args):
# load graph data
if args.dataset == 'aifb':
if args.dataset == "aifb":
dataset = AIFBDataset()
elif args.dataset == 'mutag':
elif args.dataset == "mutag":
dataset = MUTAGDataset()
elif args.dataset == 'bgs':
elif args.dataset == "bgs":
dataset = BGSDataset()
elif args.dataset == 'am':
elif args.dataset == "am":
dataset = AMDataset()
else:
raise ValueError()
......@@ -24,9 +26,9 @@ def main(args):
g = dataset[0]
category = dataset.predict_category
num_classes = dataset.num_classes
test_mask = g.nodes[category].data.pop('test_mask')
test_mask = g.nodes[category].data.pop("test_mask")
test_idx = th.nonzero(test_mask, as_tuple=False).squeeze()
labels = g.nodes[category].data.pop('labels')
labels = g.nodes[category].data.pop("labels")
# check cuda
use_cuda = args.gpu >= 0 and th.cuda.is_available()
......@@ -34,15 +36,17 @@ def main(args):
th.cuda.set_device(args.gpu)
labels = labels.cuda()
test_idx = test_idx.cuda()
g = g.to('cuda:%d' % args.gpu)
g = g.to("cuda:%d" % args.gpu)
# create model
model = EntityClassify(g,
args.n_hidden,
num_classes,
num_bases=args.n_bases,
num_hidden_layers=args.n_layers - 2,
use_self_loop=args.use_self_loop)
model = EntityClassify(
g,
args.n_hidden,
num_classes,
num_bases=args.n_bases,
num_hidden_layers=args.n_layers - 2,
use_self_loop=args.use_self_loop,
)
model.load_state_dict(th.load(args.model_path))
if use_cuda:
model.cuda()
......@@ -51,28 +55,45 @@ def main(args):
model.eval()
logits = model.forward()[category]
test_loss = F.cross_entropy(logits[test_idx], labels[test_idx])
test_acc = th.sum(logits[test_idx].argmax(dim=1) == labels[test_idx]).item() / len(test_idx)
print("Test Acc: {:.4f} | Test loss: {:.4f}".format(test_acc, test_loss.item()))
test_acc = th.sum(
logits[test_idx].argmax(dim=1) == labels[test_idx]
).item() / len(test_idx)
print(
"Test Acc: {:.4f} | Test loss: {:.4f}".format(
test_acc, test_loss.item()
)
)
print()
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='RGCN')
parser.add_argument("--n-hidden", type=int, default=16,
help="number of hidden units")
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
parser.add_argument("--lr", type=float, default=1e-2,
help="learning rate")
parser.add_argument("--n-bases", type=int, default=-1,
help="number of filter weight matrices, default: -1 [use all]")
parser.add_argument("--n-layers", type=int, default=2,
help="number of propagation rounds")
parser.add_argument("-d", "--dataset", type=str, required=True,
help="dataset to use")
parser.add_argument("--model_path", type=str,
help='path of the model to load from')
parser.add_argument("--use-self-loop", default=False, action='store_true',
help="include self feature as a special relation")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="RGCN")
parser.add_argument(
"--n-hidden", type=int, default=16, help="number of hidden units"
)
parser.add_argument("--gpu", type=int, default=-1, help="gpu")
parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
parser.add_argument(
"--n-bases",
type=int,
default=-1,
help="number of filter weight matrices, default: -1 [use all]",
)
parser.add_argument(
"--n-layers", type=int, default=2, help="number of propagation rounds"
)
parser.add_argument(
"-d", "--dataset", type=str, required=True, help="dataset to use"
)
parser.add_argument(
"--model_path", type=str, help="path of the model to load from"
)
parser.add_argument(
"--use-self-loop",
default=False,
action="store_true",
help="include self feature as a special relation",
)
args = parser.parse_args()
print(args)
......
......@@ -9,6 +9,7 @@ References:
import torch
from torch import nn
import dgl.function as fn
......@@ -21,26 +22,23 @@ class RRNLayer(nn.Module):
def forward(self, g):
g.apply_edges(self.get_msg)
g.edata['e'] = self.edge_dropout(g.edata['e'])
g.update_all(message_func=fn.copy_e('e', 'msg'),
reduce_func=fn.sum('msg', 'm'))
g.edata["e"] = self.edge_dropout(g.edata["e"])
g.update_all(
message_func=fn.copy_e("e", "msg"), reduce_func=fn.sum("msg", "m")
)
g.apply_nodes(self.node_update)
def get_msg(self, edges):
e = torch.cat([edges.src['h'], edges.dst['h']], -1)
e = torch.cat([edges.src["h"], edges.dst["h"]], -1)
e = self.msg_layer(e)
return {'e': e}
return {"e": e}
def node_update(self, nodes):
return self.node_update_func(nodes)
class RRN(nn.Module):
def __init__(self,
msg_layer,
node_update_func,
num_steps,
edge_drop):
def __init__(self, msg_layer, node_update_func, num_steps, edge_drop):
super(RRN, self).__init__()
self.num_steps = num_steps
self.rrn_layer = RRNLayer(msg_layer, node_update_func, edge_drop)
......@@ -50,11 +48,9 @@ class RRN(nn.Module):
for _ in range(self.num_steps):
self.rrn_layer(g)
if get_all_outputs:
outputs.append(g.ndata['h'])
outputs.append(g.ndata["h"])
if get_all_outputs:
outputs = torch.stack(outputs, 0) # num_steps x n_nodes x h_dim
else:
outputs = g.ndata['h'] # n_nodes x h_dim
outputs = g.ndata["h"] # n_nodes x h_dim
return outputs
......@@ -2,17 +2,13 @@
SudokuNN module based on RRN for solving sudoku puzzles
"""
import torch
from rrn import RRN
from torch import nn
import torch
class SudokuNN(nn.Module):
def __init__(self,
num_steps,
embed_size=16,
hidden_dim=96,
edge_drop=0.1):
def __init__(self, num_steps, embed_size=16, hidden_dim=96, edge_drop=0.1):
super(SudokuNN, self).__init__()
self.num_steps = num_steps
......@@ -21,7 +17,7 @@ class SudokuNN(nn.Module):
self.col_embed = nn.Embedding(9, embed_size)
self.input_layer = nn.Sequential(
nn.Linear(3*embed_size, hidden_dim),
nn.Linear(3 * embed_size, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
......@@ -30,10 +26,10 @@ class SudokuNN(nn.Module):
nn.Linear(hidden_dim, hidden_dim),
)
self.lstm = nn.LSTMCell(hidden_dim*2, hidden_dim, bias=False)
self.lstm = nn.LSTMCell(hidden_dim * 2, hidden_dim, bias=False)
msg_layer = nn.Sequential(
nn.Linear(2*hidden_dim, hidden_dim),
nn.Linear(2 * hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
......@@ -49,17 +45,17 @@ class SudokuNN(nn.Module):
self.loss_func = nn.CrossEntropyLoss()
def forward(self, g, is_training=True):
labels = g.ndata.pop('a')
labels = g.ndata.pop("a")
input_digits = self.digit_embed(g.ndata.pop('q'))
rows = self.row_embed(g.ndata.pop('row'))
cols = self.col_embed(g.ndata.pop('col'))
input_digits = self.digit_embed(g.ndata.pop("q"))
rows = self.row_embed(g.ndata.pop("row"))
cols = self.col_embed(g.ndata.pop("col"))
x = self.input_layer(torch.cat([input_digits, rows, cols], -1))
g.ndata['x'] = x
g.ndata['h'] = x
g.ndata['rnn_h'] = torch.zeros_like(x, dtype=torch.float)
g.ndata['rnn_c'] = torch.zeros_like(x, dtype=torch.float)
g.ndata["x"] = x
g.ndata["h"] = x
g.ndata["rnn_h"] = torch.zeros_like(x, dtype=torch.float)
g.ndata["rnn_c"] = torch.zeros_like(x, dtype=torch.float)
outputs = self.rrn(g, is_training)
logits = self.output_layer(outputs)
......@@ -67,15 +63,18 @@ class SudokuNN(nn.Module):
preds = torch.argmax(logits, -1)
if is_training:
labels = torch.stack([labels]*self.num_steps, 0)
labels = torch.stack([labels] * self.num_steps, 0)
logits = logits.view([-1, 10])
labels = labels.view([-1])
loss = self.loss_func(logits, labels)
return preds, loss
def node_update_func(self, nodes):
x, h, m, c = nodes.data['x'], nodes.data['rnn_h'], nodes.data['m'], nodes.data['rnn_c']
x, h, m, c = (
nodes.data["x"],
nodes.data["rnn_h"],
nodes.data["m"],
nodes.data["rnn_c"],
)
new_h, new_c = self.lstm(torch.cat([x, m], -1), (h, c))
return {'h': new_h, 'rnn_c': new_c, 'rnn_h': new_h}
return {"h": new_h, "rnn_c": new_c, "rnn_h": new_h}
......@@ -2,24 +2,28 @@ import csv
import os
import urllib.request
import zipfile
from copy import copy
import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.dataset import Dataset
import torch
import dgl
from copy import copy
def _basic_sudoku_graph():
grids = [[0, 1, 2, 9, 10, 11, 18, 19, 20],
[3, 4, 5, 12, 13, 14, 21, 22, 23],
[6, 7, 8, 15, 16, 17, 24, 25, 26],
[27, 28, 29, 36, 37, 38, 45, 46, 47],
[30, 31, 32, 39, 40, 41, 48, 49, 50],
[33, 34, 35, 42, 43, 44, 51, 52, 53],
[54, 55, 56, 63, 64, 65, 72, 73, 74],
[57, 58, 59, 66, 67, 68, 75, 76, 77],
[60, 61, 62, 69, 70, 71, 78, 79, 80]]
grids = [
[0, 1, 2, 9, 10, 11, 18, 19, 20],
[3, 4, 5, 12, 13, 14, 21, 22, 23],
[6, 7, 8, 15, 16, 17, 24, 25, 26],
[27, 28, 29, 36, 37, 38, 45, 46, 47],
[30, 31, 32, 39, 40, 41, 48, 49, 50],
[33, 34, 35, 42, 43, 44, 51, 52, 53],
[54, 55, 56, 63, 64, 65, 72, 73, 74],
[57, 58, 59, 66, 67, 68, 75, 76, 77],
[60, 61, 62, 69, 70, 71, 78, 79, 80],
]
edges = set()
for i in range(81):
row, col = i // 9, i % 9
......@@ -33,7 +37,7 @@ def _basic_sudoku_graph():
col_src += 9
# same grid
grid_row, grid_col = row // 3, col // 3
for n in grids[grid_row*3 + grid_col]:
for n in grids[grid_row * 3 + grid_col]:
if n != i:
edges.add((n, i))
edges = list(edges)
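As a sanity check on the constraint graph built above, every cell should end up with 20 distinct peers: 8 in its row, 8 in its column, and 4 more in its 3x3 box. A toy check, not part of the diff:

# Count the peers of cell 0 (row 0, col 0); the same count holds for every cell.
i = 0
row, col = i // 9, i % 9
peers = {row * 9 + c for c in range(9)} | {r * 9 + col for r in range(9)}
peers |= {
    (row // 3 * 3 + r) * 9 + (col // 3 * 3 + c)
    for r in range(3)
    for c in range(3)
}
peers.discard(i)
print(len(peers))  # 20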
......@@ -53,26 +57,26 @@ class ListDataset(Dataset):
return len(self.lists_of_data[0])
def _get_sudoku_dataset(segment='train'):
assert segment in ['train', 'valid', 'test']
def _get_sudoku_dataset(segment="train"):
assert segment in ["train", "valid", "test"]
url = "https://data.dgl.ai/dataset/sudoku-hard.zip"
zip_fname = "/tmp/sudoku-hard.zip"
dest_dir = '/tmp/sudoku-hard/'
dest_dir = "/tmp/sudoku-hard/"
if not os.path.exists(dest_dir):
print("Downloading data...")
urllib.request.urlretrieve(url, zip_fname)
with zipfile.ZipFile(zip_fname) as f:
f.extractall('/tmp/')
f.extractall("/tmp/")
def read_csv(fname):
print("Reading %s..." % fname)
with open(dest_dir + fname) as f:
reader = csv.reader(f, delimiter=',')
reader = csv.reader(f, delimiter=",")
return [(q, a) for q, a in reader]
data = read_csv(segment + '.csv')
data = read_csv(segment + ".csv")
def encode(samples):
def parse(x):
......@@ -82,12 +86,12 @@ def _get_sudoku_dataset(segment='train'):
return encoded
data = encode(data)
print(f'Number of puzzles in {segment} set : {len(data)}')
print(f"Number of puzzles in {segment} set : {len(data)}")
return data
def sudoku_dataloader(batch_size, segment='train'):
def sudoku_dataloader(batch_size, segment="train"):
"""
Get a DataLoader instance for dataset of sudoku. Every iteration of the dataloader returns
a DGLGraph instance, the ndata of the graph contains:
......@@ -104,7 +108,7 @@ def sudoku_dataloader(batch_size, segment='train'):
q, a = zip(*data)
dataset = ListDataset(q, a)
if segment == 'train':
if segment == "train":
data_sampler = RandomSampler(dataset)
else:
data_sampler = SequentialSampler(dataset)
......@@ -120,13 +124,15 @@ def sudoku_dataloader(batch_size, segment='train'):
q = torch.tensor(q, dtype=torch.long)
a = torch.tensor(a, dtype=torch.long)
graph = copy(basic_graph)
graph.ndata['q'] = q # q means question
graph.ndata['a'] = a # a means answer
graph.ndata['row'] = torch.tensor(rows, dtype=torch.long)
graph.ndata['col'] = torch.tensor(cols, dtype=torch.long)
graph.ndata["q"] = q # q means question
graph.ndata["a"] = a # a means answer
graph.ndata["row"] = torch.tensor(rows, dtype=torch.long)
graph.ndata["col"] = torch.tensor(cols, dtype=torch.long)
graph_list.append(graph)
batch_graph = dgl.batch(graph_list)
return batch_graph
dataloader = DataLoader(dataset, batch_size, sampler=data_sampler, collate_fn=collate_fn)
dataloader = DataLoader(
dataset, batch_size, sampler=data_sampler, collate_fn=collate_fn
)
return dataloader
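A hypothetical usage sketch of the loader defined above (it downloads the dataset on first use; the batch shape is approximate):

# Assumes this module (sudoku_data.py) has been imported or is being run directly.
loader = sudoku_dataloader(batch_size=64, segment="valid")
g = next(iter(loader))
print(g.ndata["q"].shape)  # roughly (64 * 81,) puzzle digits, 0 marking blank cells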
import os
import urllib.request
import torch
import numpy as np
from sudoku_data import _basic_sudoku_graph
import torch
from sudoku import SudokuNN
from sudoku_data import _basic_sudoku_graph
def solve_sudoku(puzzle):
......@@ -14,18 +14,18 @@ def solve_sudoku(puzzle):
:return: a [9, 9] shaped numpy array
"""
puzzle = np.array(puzzle, dtype=np.long).reshape([-1])
model_path = 'ckpt'
model_path = "ckpt"
if not os.path.exists(model_path):
os.mkdir(model_path)
model_filename = os.path.join(model_path, 'rrn-sudoku.pkl')
model_filename = os.path.join(model_path, "rrn-sudoku.pkl")
if not os.path.exists(model_filename):
print('Downloading model...')
url = 'https://data.dgl.ai/models/rrn-sudoku.pkl'
print("Downloading model...")
url = "https://data.dgl.ai/models/rrn-sudoku.pkl"
urllib.request.urlretrieve(url, model_filename)
model = SudokuNN(num_steps=64, edge_drop=0.)
model.load_state_dict(torch.load(model_filename, map_location='cpu'))
model = SudokuNN(num_steps=64, edge_drop=0.0)
model.load_state_dict(torch.load(model_filename, map_location="cpu"))
model.eval()
g = _basic_sudoku_graph()
......@@ -33,17 +33,17 @@ def solve_sudoku(puzzle):
rows = sudoku_indices // 9
cols = sudoku_indices % 9
g.ndata['row'] = torch.tensor(rows, dtype=torch.long)
g.ndata['col'] = torch.tensor(cols, dtype=torch.long)
g.ndata['q'] = torch.tensor(puzzle, dtype=torch.long)
g.ndata['a'] = torch.tensor(puzzle, dtype=torch.long)
g.ndata["row"] = torch.tensor(rows, dtype=torch.long)
g.ndata["col"] = torch.tensor(cols, dtype=torch.long)
g.ndata["q"] = torch.tensor(puzzle, dtype=torch.long)
g.ndata["a"] = torch.tensor(puzzle, dtype=torch.long)
pred, _ = model(g, False)
pred = pred.cpu().data.numpy().reshape([9, 9])
return pred
if __name__ == '__main__':
if __name__ == "__main__":
q = [
[9, 7, 0, 4, 0, 2, 0, 5, 3],
[0, 4, 6, 0, 9, 0, 0, 0, 0],
......@@ -53,7 +53,7 @@ if __name__ == '__main__':
[0, 0, 2, 8, 0, 0, 0, 0, 0],
[6, 0, 5, 1, 0, 7, 2, 0, 0],
[0, 0, 0, 0, 6, 0, 7, 4, 0],
[4, 3, 0, 2, 0, 9, 0, 6, 1]
[4, 3, 0, 2, 0, 9, 0, 6, 1],
]
answer = solve_sudoku(q)
......
from sudoku_data import sudoku_dataloader
import argparse
from sudoku import SudokuNN
import torch
from torch.optim import Adam
import os
import numpy as np
import torch
from sudoku import SudokuNN
from sudoku_data import sudoku_dataloader
from torch.optim import Adam
def main(args):
if args.gpu < 0 or not torch.cuda.is_available():
device = torch.device('cpu')
device = torch.device("cpu")
else:
device = torch.device('cuda', args.gpu)
device = torch.device("cuda", args.gpu)
model = SudokuNN(num_steps=args.steps, edge_drop=args.edge_drop)
......@@ -19,10 +20,12 @@ def main(args):
if not os.path.exists(args.output_dir):
os.mkdir(args.output_dir)
model.to(device)
train_dataloader = sudoku_dataloader(args.batch_size, segment='train')
dev_dataloader = sudoku_dataloader(args.batch_size, segment='valid')
train_dataloader = sudoku_dataloader(args.batch_size, segment="train")
dev_dataloader = sudoku_dataloader(args.batch_size, segment="valid")
opt = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
opt = Adam(
model.parameters(), lr=args.lr, weight_decay=args.weight_decay
)
best_dev_acc = 0.0
for epoch in range(args.epochs):
......@@ -43,7 +46,7 @@ def main(args):
dev_res = []
for g in dev_dataloader:
g = g.to(device)
target = g.ndata['a']
target = g.ndata["a"]
target = target.view([-1, 81])
with torch.no_grad():
......@@ -51,28 +54,35 @@ def main(args):
preds = preds.view([-1, 81])
for i in range(preds.size(0)):
dev_res.append(int(torch.equal(preds[i, :], target[i, :])))
dev_res.append(
int(torch.equal(preds[i, :], target[i, :]))
)
dev_loss.append(loss.cpu().detach().data)
dev_acc = sum(dev_res) / len(dev_res)
print(f"Dev loss {np.mean(dev_loss)}, accuracy {dev_acc}")
if dev_acc >= best_dev_acc:
torch.save(model.state_dict(), os.path.join(args.output_dir, 'model_best.bin'))
torch.save(
model.state_dict(),
os.path.join(args.output_dir, "model_best.bin"),
)
best_dev_acc = dev_acc
print(f"Best dev accuracy {best_dev_acc}\n")
torch.save(model.state_dict(), os.path.join(args.output_dir, 'model_final.bin'))
torch.save(
model.state_dict(), os.path.join(args.output_dir, "model_final.bin")
)
if args.do_eval:
model_path = os.path.join(args.output_dir, 'model_best.bin')
model_path = os.path.join(args.output_dir, "model_best.bin")
if not os.path.exists(model_path):
raise FileNotFoundError("Saved model not Found!")
model.load_state_dict(torch.load(model_path))
model.to(device)
test_dataloader = sudoku_dataloader(args.batch_size, segment='test')
test_dataloader = sudoku_dataloader(args.batch_size, segment="test")
print("\n=========Test step========")
model.eval()
......@@ -80,7 +90,7 @@ def main(args):
test_res = []
for g in test_dataloader:
g = g.to(device)
target = g.ndata['a']
target = g.ndata["a"]
target = target.view([-1, 81])
with torch.no_grad():
......@@ -98,30 +108,44 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Recurrent Relational Network on sudoku task.')
parser.add_argument("--output_dir", type=str, default=None, required=True,
help="The directory to save model")
parser.add_argument("--do_train", default=False, action="store_true",
help="Train the model")
parser.add_argument("--do_eval", default=False, action="store_true",
help="Evaluate the model on test data")
parser.add_argument("--epochs", type=int, default=100,
help="Number of training epochs")
parser.add_argument("--batch_size", type=int, default=64,
help="Batch size")
parser.add_argument("--edge_drop", type=float, default=0.4,
help="Dropout rate at edges.")
parser.add_argument("--steps", type=int, default=32,
help="Number of message passing steps.")
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
parser.add_argument("--lr", type=float, default=2e-4,
help="Learning rate")
parser.add_argument("--weight_decay", type=float, default=1e-4,
help="weight decay (L2 penalty)")
parser = argparse.ArgumentParser(
description="Recurrent Relational Network on sudoku task."
)
parser.add_argument(
"--output_dir",
type=str,
default=None,
required=True,
help="The directory to save model",
)
parser.add_argument(
"--do_train", default=False, action="store_true", help="Train the model"
)
parser.add_argument(
"--do_eval",
default=False,
action="store_true",
help="Evaluate the model on test data",
)
parser.add_argument(
"--epochs", type=int, default=100, help="Number of training epochs"
)
parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
parser.add_argument(
"--edge_drop", type=float, default=0.4, help="Dropout rate at edges."
)
parser.add_argument(
"--steps", type=int, default=32, help="Number of message passing steps."
)
parser.add_argument("--gpu", type=int, default=-1, help="gpu")
parser.add_argument("--lr", type=float, default=2e-4, help="Learning rate")
parser.add_argument(
"--weight_decay",
type=float,
default=1e-4,
help="weight decay (L2 penalty)",
)
args = parser.parse_args()
main(args)
import json
import os
from copy import deepcopy
from main import main, parse_args
from utils import get_stats
def load_config(path="./grid_search_config.json"):
with open(path, "r") as f:
return json.load(f)
def run_experiments(args):
res = []
for i in range(args.num_trials):
print("Trial {}/{}".format(i + 1, args.num_trials))
acc, _ = main(args)
res.append(acc)
mean, err_bd = get_stats(res, conf_interval=True)
return mean, err_bd
def grid_search(config:dict):
def grid_search(config: dict):
args = parse_args()
results = {}
for d in config["dataset"]:
args.dataset = d
best_acc, err_bd = 0., 0.
best_acc, err_bd = 0.0, 0.0
best_args = vars(args)
for arch in config["arch"]:
args.architecture = arch
......@@ -47,9 +51,10 @@ def grid_search(config:dict):
args.output_path = "./output/{}.log".format(d)
result = {
"params": best_args,
"result": "{:.4f}({:.4f})".format(best_acc, err_bd)
"result": "{:.4f}({:.4f})".format(best_acc, err_bd),
}
with open(args.output_path, "w") as f:
json.dump(result, f, sort_keys=True, indent=4)
grid_search(load_config())
import torch
import torch.nn.functional as F
from utils import get_batch_id, topk
import dgl
from dgl.nn import GraphConv, AvgPooling, MaxPooling
from utils import topk, get_batch_id
from dgl.nn import AvgPooling, GraphConv, MaxPooling
class SAGPool(torch.nn.Module):
"""The Self-Attention Pooling layer in paper
"""The Self-Attention Pooling layer in paper
`Self Attention Graph Pooling <https://arxiv.org/pdf/1904.08082.pdf>`
Args:
......@@ -18,16 +19,28 @@ class SAGPool(torch.nn.Module):
non_linearity (Callable, optional): The non-linearity function, a pytorch function.
(default: :obj:`torch.tanh`)
"""
def __init__(self, in_dim:int, ratio=0.5, conv_op=GraphConv, non_linearity=torch.tanh):
def __init__(
self,
in_dim: int,
ratio=0.5,
conv_op=GraphConv,
non_linearity=torch.tanh,
):
super(SAGPool, self).__init__()
self.in_dim = in_dim
self.ratio = ratio
self.score_layer = conv_op(in_dim, 1)
self.non_linearity = non_linearity
def forward(self, graph:dgl.DGLGraph, feature:torch.Tensor):
def forward(self, graph: dgl.DGLGraph, feature: torch.Tensor):
score = self.score_layer(graph, feature).squeeze()
perm, next_batch_num_nodes = topk(score, self.ratio, get_batch_id(graph.batch_num_nodes()), graph.batch_num_nodes())
perm, next_batch_num_nodes = topk(
score,
self.ratio,
get_batch_id(graph.batch_num_nodes()),
graph.batch_num_nodes(),
)
feature = feature[perm] * self.non_linearity(score[perm]).view(-1, 1)
graph = dgl.node_subgraph(graph, perm)
......@@ -37,7 +50,7 @@ class SAGPool(torch.nn.Module):
# Since global pooling has nothing to do with 'batch_num_edges',
# we can leave it to be None or unchanged.
graph.set_batch_num_nodes(next_batch_num_nodes)
return graph, feature, perm
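SAGPool keeps the ratio * N highest-scoring nodes of each graph; for a single graph the ranking step reduces to torch.topk. A toy sketch (the repo's utils.topk is the batched version, which also tracks per-graph node counts):

# Single-graph illustration of the selection step; scores are made up.
import torch

score = torch.tensor([0.10, 0.90, 0.30, 0.70])  # attention score per node
ratio = 0.5
k = int(ratio * score.numel())                  # keep 2 of 4 nodes
perm = torch.topk(score, k).indices             # tensor([1, 3])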
......@@ -45,15 +58,18 @@ class ConvPoolBlock(torch.nn.Module):
"""A combination of GCN layer and SAGPool layer,
followed by a concatenated (mean||sum) readout operation.
"""
def __init__(self, in_dim:int, out_dim:int, pool_ratio=0.8):
def __init__(self, in_dim: int, out_dim: int, pool_ratio=0.8):
super(ConvPoolBlock, self).__init__()
self.conv = GraphConv(in_dim, out_dim)
self.pool = SAGPool(out_dim, ratio=pool_ratio)
self.avgpool = AvgPooling()
self.maxpool = MaxPooling()
self.maxpool = MaxPooling()
def forward(self, graph, feature):
out = F.relu(self.conv(graph, feature))
graph, out, _ = self.pool(graph, out)
g_out = torch.cat([self.avgpool(graph, out), self.maxpool(graph, out)], dim=-1)
g_out = torch.cat(
[self.avgpool(graph, out), self.maxpool(graph, out)], dim=-1
)
return graph, out, g_out
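The concatenated (mean||max) readout in ConvPoolBlock doubles the feature dimension; in plain torch the same readout looks like this (toy sketch, not part of the commit):

import torch

h = torch.randn(5, 8)  # 5 nodes with 8-dimensional features
g_out = torch.cat([h.mean(dim=0), h.max(dim=0).values], dim=-1)
print(g_out.shape)     # torch.Size([16]), i.e. mean readout followed by max readout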
......@@ -3,54 +3,78 @@ import json
import logging
import os
from time import time
import dgl
import torch
import torch.nn
import torch.nn.functional as F
from dgl.data import LegacyTUDataset
from dgl.dataloading import GraphDataLoader
from torch.utils.data import random_split
from network import get_sag_network
from torch.utils.data import random_split
from utils import get_stats
import dgl
from dgl.data import LegacyTUDataset
from dgl.dataloading import GraphDataLoader
def parse_args():
parser = argparse.ArgumentParser(description="Self-Attention Graph Pooling")
parser.add_argument("--dataset", type=str, default="DD",
choices=["DD", "PROTEINS", "NCI1", "NCI109", "Mutagenicity"],
help="DD/PROTEINS/NCI1/NCI109/Mutagenicity")
parser.add_argument("--batch_size", type=int, default=128,
help="batch size")
parser.add_argument("--lr", type=float, default=5e-4,
help="learning rate")
parser.add_argument("--weight_decay", type=float, default=1e-4,
help="weight decay")
parser.add_argument("--pool_ratio", type=float, default=0.5,
help="pooling ratio")
parser.add_argument("--hid_dim", type=int, default=128,
help="hidden size")
parser.add_argument("--dropout", type=float, default=0.5,
help="dropout ratio")
parser.add_argument("--epochs", type=int, default=100000,
help="max number of training epochs")
parser.add_argument("--patience", type=int, default=50,
help="patience for early stopping")
parser.add_argument("--device", type=int, default=-1,
help="device id, -1 for cpu")
parser.add_argument("--architecture", type=str, default="hierarchical",
choices=["hierarchical", "global"],
help="model architecture")
parser.add_argument("--dataset_path", type=str, default="./dataset",
help="path to dataset")
parser.add_argument("--conv_layers", type=int, default=3,
help="number of conv layers")
parser.add_argument("--print_every", type=int, default=10,
help="print trainlog every k epochs, -1 for silent training")
parser.add_argument("--num_trials", type=int, default=1,
help="number of trials")
parser.add_argument(
"--dataset",
type=str,
default="DD",
choices=["DD", "PROTEINS", "NCI1", "NCI109", "Mutagenicity"],
help="DD/PROTEINS/NCI1/NCI109/Mutagenicity",
)
parser.add_argument(
"--batch_size", type=int, default=128, help="batch size"
)
parser.add_argument("--lr", type=float, default=5e-4, help="learning rate")
parser.add_argument(
"--weight_decay", type=float, default=1e-4, help="weight decay"
)
parser.add_argument(
"--pool_ratio", type=float, default=0.5, help="pooling ratio"
)
parser.add_argument("--hid_dim", type=int, default=128, help="hidden size")
parser.add_argument(
"--dropout", type=float, default=0.5, help="dropout ratio"
)
parser.add_argument(
"--epochs",
type=int,
default=100000,
help="max number of training epochs",
)
parser.add_argument(
"--patience", type=int, default=50, help="patience for early stopping"
)
parser.add_argument(
"--device", type=int, default=-1, help="device id, -1 for cpu"
)
parser.add_argument(
"--architecture",
type=str,
default="hierarchical",
choices=["hierarchical", "global"],
help="model architecture",
)
parser.add_argument(
"--dataset_path", type=str, default="./dataset", help="path to dataset"
)
parser.add_argument(
"--conv_layers", type=int, default=3, help="number of conv layers"
)
parser.add_argument(
"--print_every",
type=int,
default=10,
help="print trainlog every k epochs, -1 for silent training",
)
parser.add_argument(
"--num_trials", type=int, default=1, help="number of trials"
)
parser.add_argument("--output_path", type=str, default="./output")
args = parser.parse_args()
# device
......@@ -69,15 +93,21 @@ def parse_args():
if not os.path.exists(args.output_path):
os.makedirs(args.output_path)
name = "Data={}_Hidden={}_Arch={}_Pool={}_WeightDecay={}_Lr={}.log".format(
args.dataset, args.hid_dim, args.architecture, args.pool_ratio, args.weight_decay, args.lr)
args.dataset,
args.hid_dim,
args.architecture,
args.pool_ratio,
args.weight_decay,
args.lr,
)
args.output_path = os.path.join(args.output_path, name)
return args
def train(model:torch.nn.Module, optimizer, trainloader, device):
def train(model: torch.nn.Module, optimizer, trainloader, device):
model.train()
total_loss = 0.
total_loss = 0.0
num_batches = len(trainloader)
for batch in trainloader:
optimizer.zero_grad()
......@@ -90,15 +120,15 @@ def train(model:torch.nn.Module, optimizer, trainloader, device):
optimizer.step()
total_loss += loss.item()
return total_loss / num_batches
@torch.no_grad()
def test(model:torch.nn.Module, loader, device):
def test(model: torch.nn.Module, loader, device):
model.eval()
correct = 0.
loss = 0.
correct = 0.0
loss = 0.0
num_graphs = 0
for batch in loader:
batch_graphs, batch_labels = batch
......@@ -124,29 +154,45 @@ def main(args):
num_training = int(len(dataset) * 0.8)
num_val = int(len(dataset) * 0.1)
num_test = len(dataset) - num_val - num_training
train_set, val_set, test_set = random_split(dataset, [num_training, num_val, num_test])
train_loader = GraphDataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=6)
val_loader = GraphDataLoader(val_set, batch_size=args.batch_size, num_workers=2)
test_loader = GraphDataLoader(test_set, batch_size=args.batch_size, num_workers=2)
train_set, val_set, test_set = random_split(
dataset, [num_training, num_val, num_test]
)
train_loader = GraphDataLoader(
train_set, batch_size=args.batch_size, shuffle=True, num_workers=6
)
val_loader = GraphDataLoader(
val_set, batch_size=args.batch_size, num_workers=2
)
test_loader = GraphDataLoader(
test_set, batch_size=args.batch_size, num_workers=2
)
device = torch.device(args.device)
# Step 2: Create model =================================================================== #
num_feature, num_classes, _ = dataset.statistics()
model_op = get_sag_network(args.architecture)
model = model_op(in_dim=num_feature, hid_dim=args.hid_dim, out_dim=num_classes,
num_convs=args.conv_layers, pool_ratio=args.pool_ratio, dropout=args.dropout).to(device)
model = model_op(
in_dim=num_feature,
hid_dim=args.hid_dim,
out_dim=num_classes,
num_convs=args.conv_layers,
pool_ratio=args.pool_ratio,
dropout=args.dropout,
).to(device)
args.num_feature = int(num_feature)
args.num_classes = int(num_classes)
# Step 3: Create training components ===================================================== #
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
optimizer = torch.optim.Adam(
model.parameters(), lr=args.lr, weight_decay=args.weight_decay
)
# Step 4: training epoches =============================================================== #
bad_cound = 0
best_val_loss = float("inf")
final_test_acc = 0.
final_test_acc = 0.0
best_epoch = 0
train_times = []
for e in range(args.epochs):
......@@ -164,11 +210,17 @@ def main(args):
bad_cound += 1
if bad_cound >= args.patience:
break
if (e + 1) % args.print_every == 0:
log_format = "Epoch {}: loss={:.4f}, val_acc={:.4f}, final_test_acc={:.4f}"
log_format = (
"Epoch {}: loss={:.4f}, val_acc={:.4f}, final_test_acc={:.4f}"
)
print(log_format.format(e + 1, train_loss, val_acc, final_test_acc))
print("Best Epoch {}, final test acc {:.4f}".format(best_epoch, final_test_acc))
print(
"Best Epoch {}, final test acc {:.4f}".format(
best_epoch, final_test_acc
)
)
return final_test_acc, sum(train_times) / len(train_times)
......@@ -185,9 +237,11 @@ if __name__ == "__main__":
mean, err_bd = get_stats(res)
print("mean acc: {:.4f}, error bound: {:.4f}".format(mean, err_bd))
out_dict = {"hyper-parameters": vars(args),
"result": "{:.4f}(+-{:.4f})".format(mean, err_bd),
"train_time": "{:.4f}".format(sum(train_times) / len(train_times))}
out_dict = {
"hyper-parameters": vars(args),
"result": "{:.4f}(+-{:.4f})".format(mean, err_bd),
"train_time": "{:.4f}".format(sum(train_times) / len(train_times)),
}
with open(args.output_path, "w") as f:
json.dump(out_dict, f, sort_keys=True, indent=4)
import torch
import torch.nn
import torch.nn.functional as F
import dgl
from dgl.nn import GraphConv, AvgPooling, MaxPooling
from layer import ConvPoolBlock, SAGPool
import dgl
from dgl.nn import AvgPooling, GraphConv, MaxPooling
class SAGNetworkHierarchical(torch.nn.Module):
"""The Self-Attention Graph Pooling Network with hierarchical readout in paper
......@@ -21,25 +21,35 @@ class SAGNetworkHierarchical(torch.nn.Module):
remain after pooling. (default: :obj:`0.5`)
dropout (float, optional): The dropout ratio for each layer. (default: 0)
"""
def __init__(self, in_dim:int, hid_dim:int, out_dim:int, num_convs=3,
pool_ratio:float=0.5, dropout:float=0.0):
def __init__(
self,
in_dim: int,
hid_dim: int,
out_dim: int,
num_convs=3,
pool_ratio: float = 0.5,
dropout: float = 0.0,
):
super(SAGNetworkHierarchical, self).__init__()
self.dropout = dropout
self.num_convpools = num_convs
convpools = []
for i in range(num_convs):
_i_dim = in_dim if i == 0 else hid_dim
_o_dim = hid_dim
convpools.append(ConvPoolBlock(_i_dim, _o_dim, pool_ratio=pool_ratio))
convpools.append(
ConvPoolBlock(_i_dim, _o_dim, pool_ratio=pool_ratio)
)
self.convpools = torch.nn.ModuleList(convpools)
self.lin1 = torch.nn.Linear(hid_dim * 2, hid_dim)
self.lin2 = torch.nn.Linear(hid_dim, hid_dim // 2)
self.lin3 = torch.nn.Linear(hid_dim // 2, out_dim)
def forward(self, graph:dgl.DGLGraph):
def forward(self, graph: dgl.DGLGraph):
feat = graph.ndata["feat"]
final_readout = None
......@@ -49,12 +59,12 @@ class SAGNetworkHierarchical(torch.nn.Module):
final_readout = readout
else:
final_readout = final_readout + readout
feat = F.relu(self.lin1(final_readout))
feat = F.dropout(feat, p=self.dropout, training=self.training)
feat = F.relu(self.lin2(feat))
feat = F.log_softmax(self.lin3(feat), dim=-1)
return feat
......@@ -72,8 +82,16 @@ class SAGNetworkGlobal(torch.nn.Module):
remain after pooling. (default: :obj:`0.5`)
dropout (float, optional): The dropout ratio for each layer. (default: 0)
"""
def __init__(self, in_dim:int, hid_dim:int, out_dim:int, num_convs=3,
pool_ratio:float=0.5, dropout:float=0.0):
def __init__(
self,
in_dim: int,
hid_dim: int,
out_dim: int,
num_convs=3,
pool_ratio: float = 0.5,
dropout: float = 0.0,
):
super(SAGNetworkGlobal, self).__init__()
self.dropout = dropout
self.num_convs = num_convs
......@@ -93,18 +111,21 @@ class SAGNetworkGlobal(torch.nn.Module):
self.lin1 = torch.nn.Linear(concat_dim * 2, hid_dim)
self.lin2 = torch.nn.Linear(hid_dim, hid_dim // 2)
self.lin3 = torch.nn.Linear(hid_dim // 2, out_dim)
def forward(self, graph:dgl.DGLGraph):
def forward(self, graph: dgl.DGLGraph):
feat = graph.ndata["feat"]
conv_res = []
for i in range(self.num_convs):
feat = self.convs[i](graph, feat)
conv_res.append(feat)
conv_res = torch.cat(conv_res, dim=-1)
graph, feat, _ = self.pool(graph, conv_res)
feat = torch.cat([self.avg_readout(graph, feat), self.max_readout(graph, feat)], dim=-1)
feat = torch.cat(
[self.avg_readout(graph, feat), self.max_readout(graph, feat)],
dim=-1,
)
feat = F.relu(self.lin1(feat))
feat = F.dropout(feat, p=self.dropout, training=self.training)
......@@ -114,10 +135,12 @@ class SAGNetworkGlobal(torch.nn.Module):
return feat
def get_sag_network(net_type:str="hierarchical"):
def get_sag_network(net_type: str = "hierarchical"):
if net_type == "hierarchical":
return SAGNetworkHierarchical
elif net_type == "global":
return SAGNetworkGlobal
else:
raise ValueError("SAGNetwork type {} is not supported.".format(net_type))
raise ValueError(
"SAGNetwork type {} is not supported.".format(net_type)
)
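# A minimal usage sketch of the factory above, mirroring the call in main.py;
# the dimensions (89 features, 2 classes, 128 hidden) are illustrative placeholders.
net_cls = get_sag_network("hierarchical")  # or "global"
model = net_cls(
    in_dim=89,
    hid_dim=128,
    out_dim=2,
    num_convs=3,
    pool_ratio=0.5,
    dropout=0.5,
)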
import torch
import logging
from scipy.stats import t
import math
import torch
from scipy.stats import t
def get_stats(array, conf_interval=False, name=None, stdout=False, logout=False):
def get_stats(
array, conf_interval=False, name=None, stdout=False, logout=False
):
"""Compute mean and standard deviation from an numerical array
Args:
array (array like obj): The numerical array, this array can be
array (array like obj): The numerical array, this array can be
converted to :obj:`torch.Tensor`.
conf_interval (bool, optional): If True, compute the confidence interval bound (95%)
instead of the std value. (default: :obj:`False`)
name (str, optional): The name of this numerical array, for log usage.
(default: :obj:`None`)
stdout (bool, optional): Whether to output result to the terminal.
stdout (bool, optional): Whether to output result to the terminal.
(default: :obj:`False`)
logout (bool, optional): Whether to output result via logging module.
(default: :obj:`False`)
......@@ -29,7 +32,7 @@ def get_stats(array, conf_interval=False, name=None, stdout=False, logout=False)
if conf_interval:
n = array.size(0)
se = std / (math.sqrt(n) + eps)
t_value = t.ppf(0.975, df=n-1)
t_value = t.ppf(0.975, df=n - 1)
err_bound = t_value * se
else:
err_bound = std
......@@ -46,7 +49,7 @@ def get_stats(array, conf_interval=False, name=None, stdout=False, logout=False)
return center, err_bound
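# A minimal usage sketch of get_stats, assuming a plain Python list of trial
# accuracies (values are illustrative); with conf_interval=True the error bound
# becomes t_{0.975, n-1} * std / sqrt(n) instead of the raw standard deviation.
accs = [0.78, 0.81, 0.79, 0.80, 0.82]
mean, std = get_stats(accs)
mean, ci = get_stats(accs, conf_interval=True)
print("{:.4f}(+-{:.4f})".format(mean, ci))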
def get_batch_id(num_nodes:torch.Tensor):
def get_batch_id(num_nodes: torch.Tensor):
"""Convert the num_nodes array obtained from batch graph to batch_id array
for each node.
......@@ -57,12 +60,19 @@ def get_batch_id(num_nodes:torch.Tensor):
batch_size = num_nodes.size(0)
batch_ids = []
for i in range(batch_size):
item = torch.full((num_nodes[i],), i, dtype=torch.long, device=num_nodes.device)
item = torch.full(
(num_nodes[i],), i, dtype=torch.long, device=num_nodes.device
)
batch_ids.append(item)
return torch.cat(batch_ids)
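# A minimal sketch of get_batch_id: the per-graph node counts of a batched graph
# are expanded into a per-node graph index (values are illustrative).
num_nodes = torch.tensor([3, 2])
batch_id = get_batch_id(num_nodes)  # tensor([0, 0, 0, 1, 1])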
def topk(x:torch.Tensor, ratio:float, batch_id:torch.Tensor, num_nodes:torch.Tensor):
def topk(
x: torch.Tensor,
ratio: float,
batch_id: torch.Tensor,
num_nodes: torch.Tensor,
):
"""The top-k pooling method. Given a graph batch, this method will pool out some
nodes from the input node feature tensor for each graph according to the given ratio.
......@@ -72,21 +82,23 @@ def topk(x:torch.Tensor, ratio:float, batch_id:torch.Tensor, num_nodes:torch.Ten
tensor will be pooled out.
batch_id (torch.Tensor): The batch_id of each element in the input tensor.
num_nodes (torch.Tensor): The number of nodes of each graph in batch.
Returns:
perm (torch.Tensor): The index in batch to be kept.
k (torch.Tensor): The remaining number of nodes for each graph.
"""
batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item()
cum_num_nodes = torch.cat(
[num_nodes.new_zeros(1),
num_nodes.cumsum(dim=0)[:-1]], dim=0)
[num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]], dim=0
)
index = torch.arange(batch_id.size(0), dtype=torch.long, device=x.device)
index = (index - cum_num_nodes[batch_id]) + (batch_id * max_num_nodes)
dense_x = x.new_full((batch_size * max_num_nodes, ), torch.finfo(x.dtype).min)
dense_x = x.new_full(
(batch_size * max_num_nodes,), torch.finfo(x.dtype).min
)
dense_x[index] = x
dense_x = dense_x.view(batch_size, max_num_nodes)
......@@ -96,8 +108,10 @@ def topk(x:torch.Tensor, ratio:float, batch_id:torch.Tensor, num_nodes:torch.Ten
k = (ratio * num_nodes.to(torch.float)).ceil().to(torch.long)
mask = [
torch.arange(k[i], dtype=torch.long, device=x.device) +
i * max_num_nodes for i in range(batch_size)]
torch.arange(k[i], dtype=torch.long, device=x.device)
+ i * max_num_nodes
for i in range(batch_size)
]
mask = torch.cat(mask, dim=0)
perm = perm[mask]
......
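# A minimal sketch of topk on a batch of two graphs with 3 and 2 nodes
# (scores are illustrative): with ratio=0.5 it keeps ceil(0.5 * 3) = 2 and
# ceil(0.5 * 2) = 1 of the highest-scoring nodes in each graph.
scores = torch.tensor([0.1, 0.9, 0.5, 0.3, 0.7])
num_nodes = torch.tensor([3, 2])
batch_id = get_batch_id(num_nodes)  # tensor([0, 0, 0, 1, 1])
perm, k = topk(scores, 0.5, batch_id, num_nodes)
# perm -> tensor([1, 2, 4]), the kept node positions in the batched ordering
# k    -> tensor([2, 1]),    the number of nodes kept per graph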
import logging
import time
import os
import time
def _transform_log_level(str_level):
if str_level == 'info':
if str_level == "info":
return logging.INFO
elif str_level == 'warning':
elif str_level == "warning":
return logging.WARNING
elif str_level == 'critical':
elif str_level == "critical":
return logging.CRITICAL
elif str_level == 'debug':
elif str_level == "debug":
return logging.DEBUG
elif str_level == 'error':
elif str_level == "error":
return logging.ERROR
else:
raise KeyError('Log level error')
raise KeyError("Log level error")
class LightLogging(object):
def __init__(self, log_path=None, log_name='lightlog', log_level='debug'):
def __init__(self, log_path=None, log_name="lightlog", log_level="debug"):
log_level = _transform_log_level(log_level)
if log_path:
if not log_path.endswith('/'):
log_path += '/'
if not log_path.endswith("/"):
log_path += "/"
if not os.path.exists(log_path):
os.mkdir(log_path)
if log_name.endswith('-') or log_name.endswith('_'):
log_name = log_path+log_name + time.strftime('%Y-%m-%d-%H:%M', time.localtime(time.time())) + '.log'
if log_name.endswith("-") or log_name.endswith("_"):
log_name = (
log_path
+ log_name
+ time.strftime(
"%Y-%m-%d-%H:%M", time.localtime(time.time())
)
+ ".log"
)
else:
log_name = log_path+log_name + '_' + time.strftime('%Y-%m-%d-%H-%M', time.localtime(time.time())) + '.log'
log_name = (
log_path
+ log_name
+ "_"
+ time.strftime(
"%Y-%m-%d-%H-%M", time.localtime(time.time())
)
+ ".log"
)
logging.basicConfig(level=log_level,
format="%(asctime)s %(levelname)s: %(message)s",
datefmt='%Y-%m-%d-%H:%M',
handlers=[
logging.FileHandler(log_name, mode='w'),
logging.StreamHandler()
])
logging.info('Start Logging')
logging.info('Log file path: {}'.format(log_name))
logging.basicConfig(
level=log_level,
format="%(asctime)s %(levelname)s: %(message)s",
datefmt="%Y-%m-%d-%H:%M",
handlers=[
logging.FileHandler(log_name, mode="w"),
logging.StreamHandler(),
],
)
logging.info("Start Logging")
logging.info("Log file path: {}".format(log_name))
else:
logging.basicConfig(level=log_level,
format="%(asctime)s %(levelname)s: %(message)s",
datefmt='%Y-%m-%d-%H:%M',
handlers=[
logging.StreamHandler()
])
logging.info('Start Logging')
logging.basicConfig(
level=log_level,
format="%(asctime)s %(levelname)s: %(message)s",
datefmt="%Y-%m-%d-%H:%M",
handlers=[logging.StreamHandler()],
)
logging.info("Start Logging")
def debug(self, msg):
logging.debug(msg)
......@@ -66,4 +83,4 @@ class LightLogging(object):
logging.warning(msg)
def error(self, msg):
logging.error(msg)
\ No newline at end of file
logging.error(msg)
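# A minimal usage sketch of LightLogging: with no log_path only a StreamHandler
# is attached; passing log_path additionally writes a timestamped .log file, as
# the SEAL entry script below does with log_name="SEAL", log_path="./logs".
logger = LightLogging(log_level="info")
logger.info("message goes to stdout only")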
import time
from tqdm import tqdm
import numpy as np
import torch
import torch.multiprocessing
from logger import LightLogging
from model import DGCNN, GCN
from sampler import SEALData
from torch.nn import BCEWithLogitsLoss
from dgl import NID, EID
from tqdm import tqdm
from utils import evaluate_hits, load_ogb_dataset, parse_arguments
from dgl import EID, NID
from dgl.dataloading import GraphDataLoader
from utils import parse_arguments
from utils import load_ogb_dataset, evaluate_hits
from sampler import SEALData
from model import GCN, DGCNN
from logger import LightLogging
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
'''
torch.multiprocessing.set_sharing_strategy("file_system")
"""
Part of the code is adapted from
https://github.com/facebookresearch/SEAL_OGB
'''
def train(model, dataloader, loss_fn, optimizer, device, num_graphs=32, total_graphs=None):
"""
def train(
model,
dataloader,
loss_fn,
optimizer,
device,
num_graphs=32,
total_graphs=None,
):
model.train()
total_loss = 0
......@@ -27,7 +37,7 @@ def train(model, dataloader, loss_fn, optimizer, device, num_graphs=32, total_gr
g = g.to(device)
labels = labels.to(device)
optimizer.zero_grad()
logits = model(g, g.ndata['z'], g.ndata[NID], g.edata[EID])
logits = model(g, g.ndata["z"], g.ndata[NID], g.edata[EID])
loss = loss_fn(logits, labels)
loss.backward()
optimizer.step()
......@@ -43,7 +53,7 @@ def evaluate(model, dataloader, device):
y_pred, y_true = [], []
for g, labels in tqdm(dataloader, ncols=100):
g = g.to(device)
logits = model(g, g.ndata['z'], g.ndata[NID], g.edata[EID])
logits = model(g, g.ndata["z"], g.ndata[NID], g.edata[EID])
y_pred.append(logits.view(-1).cpu())
y_true.append(labels.view(-1).cpu().to(torch.float))
......@@ -62,7 +72,7 @@ def main(args, print_fn=print):
else:
torch.manual_seed(123)
# Load dataset
if args.dataset.startswith('ogbl'):
if args.dataset.startswith("ogbl"):
graph, split_edge = load_ogb_dataset(args.dataset)
else:
raise NotImplementedError
......@@ -71,11 +81,11 @@ def main(args, print_fn=print):
# set gpu
if args.gpu_id >= 0 and torch.cuda.is_available():
device = 'cuda:{}'.format(args.gpu_id)
device = "cuda:{}".format(args.gpu_id)
else:
device = 'cpu'
device = "cpu"
if args.dataset == 'ogbl-collab':
if args.dataset == "ogbl-collab":
# The ogbl-collab dataset is a multi-edge graph
use_coalesce = True
else:
......@@ -83,94 +93,135 @@ def main(args, print_fn=print):
# Generate positive and negative edges and corresponding labels
# Sample subgraphs and generate node labeling features
seal_data = SEALData(g=graph, split_edge=split_edge, hop=args.hop, neg_samples=args.neg_samples,
subsample_ratio=args.subsample_ratio, use_coalesce=use_coalesce, prefix=args.dataset,
save_dir=args.save_dir, num_workers=args.num_workers, print_fn=print_fn)
node_attribute = seal_data.ndata['feat']
edge_weight = seal_data.edata['weight'].float()
train_data = seal_data('train')
val_data = seal_data('valid')
test_data = seal_data('test')
seal_data = SEALData(
g=graph,
split_edge=split_edge,
hop=args.hop,
neg_samples=args.neg_samples,
subsample_ratio=args.subsample_ratio,
use_coalesce=use_coalesce,
prefix=args.dataset,
save_dir=args.save_dir,
num_workers=args.num_workers,
print_fn=print_fn,
)
node_attribute = seal_data.ndata["feat"]
edge_weight = seal_data.edata["weight"].float()
train_data = seal_data("train")
val_data = seal_data("valid")
test_data = seal_data("test")
train_graphs = len(train_data.graph_list)
# Set data loader
train_loader = GraphDataLoader(train_data, batch_size=args.batch_size, num_workers=args.num_workers)
val_loader = GraphDataLoader(val_data, batch_size=args.batch_size, num_workers=args.num_workers)
test_loader = GraphDataLoader(test_data, batch_size=args.batch_size, num_workers=args.num_workers)
train_loader = GraphDataLoader(
train_data, batch_size=args.batch_size, num_workers=args.num_workers
)
val_loader = GraphDataLoader(
val_data, batch_size=args.batch_size, num_workers=args.num_workers
)
test_loader = GraphDataLoader(
test_data, batch_size=args.batch_size, num_workers=args.num_workers
)
# set model
if args.model == 'gcn':
model = GCN(num_layers=args.num_layers,
hidden_units=args.hidden_units,
gcn_type=args.gcn_type,
pooling_type=args.pooling,
node_attributes=node_attribute,
edge_weights=edge_weight,
node_embedding=None,
use_embedding=True,
num_nodes=num_nodes,
dropout=args.dropout)
elif args.model == 'dgcnn':
model = DGCNN(num_layers=args.num_layers,
hidden_units=args.hidden_units,
k=args.sort_k,
gcn_type=args.gcn_type,
node_attributes=node_attribute,
edge_weights=edge_weight,
node_embedding=None,
use_embedding=True,
num_nodes=num_nodes,
dropout=args.dropout)
if args.model == "gcn":
model = GCN(
num_layers=args.num_layers,
hidden_units=args.hidden_units,
gcn_type=args.gcn_type,
pooling_type=args.pooling,
node_attributes=node_attribute,
edge_weights=edge_weight,
node_embedding=None,
use_embedding=True,
num_nodes=num_nodes,
dropout=args.dropout,
)
elif args.model == "dgcnn":
model = DGCNN(
num_layers=args.num_layers,
hidden_units=args.hidden_units,
k=args.sort_k,
gcn_type=args.gcn_type,
node_attributes=node_attribute,
edge_weights=edge_weight,
node_embedding=None,
use_embedding=True,
num_nodes=num_nodes,
dropout=args.dropout,
)
else:
raise ValueError('Model error')
raise ValueError("Model error")
model = model.to(device)
parameters = model.parameters()
optimizer = torch.optim.Adam(parameters, lr=args.lr)
loss_fn = BCEWithLogitsLoss()
print_fn("Total parameters: {}".format(sum([p.numel() for p in model.parameters()])))
print_fn(
"Total parameters: {}".format(
sum([p.numel() for p in model.parameters()])
)
)
# train and evaluate loop
summary_val = []
summary_test = []
for epoch in range(args.epochs):
start_time = time.time()
loss = train(model=model,
dataloader=train_loader,
loss_fn=loss_fn,
optimizer=optimizer,
device=device,
num_graphs=args.batch_size,
total_graphs=train_graphs)
loss = train(
model=model,
dataloader=train_loader,
loss_fn=loss_fn,
optimizer=optimizer,
device=device,
num_graphs=args.batch_size,
total_graphs=train_graphs,
)
train_time = time.time()
if epoch % args.eval_steps == 0:
val_pos_pred, val_neg_pred = evaluate(model=model,
dataloader=val_loader,
device=device)
test_pos_pred, test_neg_pred = evaluate(model=model,
dataloader=test_loader,
device=device)
val_metric = evaluate_hits(args.dataset, val_pos_pred, val_neg_pred, args.hits_k)
test_metric = evaluate_hits(args.dataset, test_pos_pred, test_neg_pred, args.hits_k)
val_pos_pred, val_neg_pred = evaluate(
model=model, dataloader=val_loader, device=device
)
test_pos_pred, test_neg_pred = evaluate(
model=model, dataloader=test_loader, device=device
)
val_metric = evaluate_hits(
args.dataset, val_pos_pred, val_neg_pred, args.hits_k
)
test_metric = evaluate_hits(
args.dataset, test_pos_pred, test_neg_pred, args.hits_k
)
evaluate_time = time.time()
print_fn("Epoch-{}, train loss: {:.4f}, hits@{}: val-{:.4f}, test-{:.4f}, "
"cost time: train-{:.1f}s, total-{:.1f}s".format(epoch, loss, args.hits_k, val_metric, test_metric,
train_time - start_time,
evaluate_time - start_time))
print_fn(
"Epoch-{}, train loss: {:.4f}, hits@{}: val-{:.4f}, test-{:.4f}, "
"cost time: train-{:.1f}s, total-{:.1f}s".format(
epoch,
loss,
args.hits_k,
val_metric,
test_metric,
train_time - start_time,
evaluate_time - start_time,
)
)
summary_val.append(val_metric)
summary_test.append(test_metric)
summary_test = np.array(summary_test)
print_fn("Experiment Results:")
print_fn("Best hits@{}: {:.4f}, epoch: {}".format(args.hits_k, np.max(summary_test), np.argmax(summary_test)))
print_fn(
"Best hits@{}: {:.4f}, epoch: {}".format(
args.hits_k, np.max(summary_test), np.argmax(summary_test)
)
)
if __name__ == '__main__':
if __name__ == "__main__":
args = parse_arguments()
logger = LightLogging(log_name='SEAL', log_path='./logs')
logger = LightLogging(log_name="SEAL", log_path="./logs")
main(args, logger.info)
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import SortPooling, SumPooling
from dgl.nn.pytorch import GraphConv, SAGEConv
from dgl.nn.pytorch import GraphConv, SAGEConv, SortPooling, SumPooling
class GCN(nn.Module):
......@@ -26,9 +26,20 @@ class GCN(nn.Module):
"""
def __init__(self, num_layers, hidden_units, gcn_type='gcn', pooling_type='sum', node_attributes=None,
edge_weights=None, node_embedding=None, use_embedding=False,
num_nodes=None, dropout=0.5, max_z=1000):
def __init__(
self,
num_layers,
hidden_units,
gcn_type="gcn",
pooling_type="sum",
node_attributes=None,
edge_weights=None,
node_embedding=None,
use_embedding=False,
num_nodes=None,
dropout=0.5,
max_z=1000,
):
super(GCN, self).__init__()
self.num_layers = num_layers
self.dropout = dropout
......@@ -39,10 +50,14 @@ class GCN(nn.Module):
self.z_embedding = nn.Embedding(max_z, hidden_units)
if node_attributes is not None:
self.node_attributes_lookup = nn.Embedding.from_pretrained(node_attributes)
self.node_attributes_lookup = nn.Embedding.from_pretrained(
node_attributes
)
self.node_attributes_lookup.weight.requires_grad = False
if edge_weights is not None:
self.edge_weights_lookup = nn.Embedding.from_pretrained(edge_weights)
self.edge_weights_lookup = nn.Embedding.from_pretrained(
edge_weights
)
self.edge_weights_lookup.weight.requires_grad = False
if node_embedding is not None:
self.node_embedding = nn.Embedding.from_pretrained(node_embedding)
......@@ -57,21 +72,31 @@ class GCN(nn.Module):
initial_dim += self.node_embedding.embedding_dim
self.layers = nn.ModuleList()
if gcn_type == 'gcn':
self.layers.append(GraphConv(initial_dim, hidden_units, allow_zero_in_degree=True))
if gcn_type == "gcn":
self.layers.append(
GraphConv(initial_dim, hidden_units, allow_zero_in_degree=True)
)
for _ in range(num_layers - 1):
self.layers.append(GraphConv(hidden_units, hidden_units, allow_zero_in_degree=True))
elif gcn_type == 'sage':
self.layers.append(SAGEConv(initial_dim, hidden_units, aggregator_type='gcn'))
self.layers.append(
GraphConv(
hidden_units, hidden_units, allow_zero_in_degree=True
)
)
elif gcn_type == "sage":
self.layers.append(
SAGEConv(initial_dim, hidden_units, aggregator_type="gcn")
)
for _ in range(num_layers - 1):
self.layers.append(SAGEConv(hidden_units, hidden_units, aggregator_type='gcn'))
self.layers.append(
SAGEConv(hidden_units, hidden_units, aggregator_type="gcn")
)
else:
raise ValueError('Gcn type error.')
raise ValueError("Gcn type error.")
self.linear_1 = nn.Linear(hidden_units, hidden_units)
self.linear_2 = nn.Linear(hidden_units, 1)
if pooling_type != 'sum':
raise ValueError('Pooling type error.')
if pooling_type != "sum":
raise ValueError("Pooling type error.")
self.pooling = SumPooling()
def reset_parameters(self):
......@@ -141,8 +166,20 @@ class DGCNN(nn.Module):
max_z(int, optional): maximum vocabulary size of the node labeling. (default: 1000)
"""
def __init__(self, num_layers, hidden_units, k=10, gcn_type='gcn', node_attributes=None,
edge_weights=None, node_embedding=None, use_embedding=False, num_nodes=None, dropout=0.5, max_z=1000):
def __init__(
self,
num_layers,
hidden_units,
k=10,
gcn_type="gcn",
node_attributes=None,
edge_weights=None,
node_embedding=None,
use_embedding=False,
num_nodes=None,
dropout=0.5,
max_z=1000,
):
super(DGCNN, self).__init__()
self.num_layers = num_layers
self.dropout = dropout
......@@ -153,10 +190,14 @@ class DGCNN(nn.Module):
self.z_embedding = nn.Embedding(max_z, hidden_units)
if node_attributes is not None:
self.node_attributes_lookup = nn.Embedding.from_pretrained(node_attributes)
self.node_attributes_lookup = nn.Embedding.from_pretrained(
node_attributes
)
self.node_attributes_lookup.weight.requires_grad = False
if edge_weights is not None:
self.edge_weights_lookup = nn.Embedding.from_pretrained(edge_weights)
self.edge_weights_lookup = nn.Embedding.from_pretrained(
edge_weights
)
self.edge_weights_lookup.weight.requires_grad = False
if node_embedding is not None:
self.node_embedding = nn.Embedding.from_pretrained(node_embedding)
......@@ -171,28 +212,42 @@ class DGCNN(nn.Module):
initial_dim += self.node_embedding.embedding_dim
self.layers = nn.ModuleList()
if gcn_type == 'gcn':
self.layers.append(GraphConv(initial_dim, hidden_units, allow_zero_in_degree=True))
if gcn_type == "gcn":
self.layers.append(
GraphConv(initial_dim, hidden_units, allow_zero_in_degree=True)
)
for _ in range(num_layers - 1):
self.layers.append(GraphConv(hidden_units, hidden_units, allow_zero_in_degree=True))
self.layers.append(GraphConv(hidden_units, 1, allow_zero_in_degree=True))
elif gcn_type == 'sage':
self.layers.append(SAGEConv(initial_dim, hidden_units, aggregator_type='gcn'))
self.layers.append(
GraphConv(
hidden_units, hidden_units, allow_zero_in_degree=True
)
)
self.layers.append(
GraphConv(hidden_units, 1, allow_zero_in_degree=True)
)
elif gcn_type == "sage":
self.layers.append(
SAGEConv(initial_dim, hidden_units, aggregator_type="gcn")
)
for _ in range(num_layers - 1):
self.layers.append(SAGEConv(hidden_units, hidden_units, aggregator_type='gcn'))
self.layers.append(SAGEConv(hidden_units, 1, aggregator_type='gcn'))
self.layers.append(
SAGEConv(hidden_units, hidden_units, aggregator_type="gcn")
)
self.layers.append(SAGEConv(hidden_units, 1, aggregator_type="gcn"))
else:
raise ValueError('Gcn type error.')
raise ValueError("Gcn type error.")
self.pooling = SortPooling(k=k)
conv1d_channels = [16, 32]
total_latent_dim = hidden_units * num_layers + 1
conv1d_kws = [total_latent_dim, 5]
self.conv_1 = nn.Conv1d(1, conv1d_channels[0], conv1d_kws[0],
conv1d_kws[0])
self.conv_1 = nn.Conv1d(
1, conv1d_channels[0], conv1d_kws[0], conv1d_kws[0]
)
self.maxpool1d = nn.MaxPool1d(2, 2)
self.conv_2 = nn.Conv1d(conv1d_channels[0], conv1d_channels[1],
conv1d_kws[1], 1)
self.conv_2 = nn.Conv1d(
conv1d_channels[0], conv1d_channels[1], conv1d_kws[1], 1
)
dense_dim = int((k - 2) / 2 + 1)
dense_dim = (dense_dim - conv1d_kws[1] + 1) * conv1d_channels[1]
self.linear_1 = nn.Linear(dense_dim, 128)
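# Tracing the read-out head above with the defaults k=30, hidden_units=32,
# num_layers=3 (a shape sanity-check sketch, no extra functionality):
#   SortPooling output per graph: k * total_latent_dim = 30 * (32 * 3 + 1) = 30 * 97
#   conv_1 (kernel = stride = 97): length 30, 16 channels
#   maxpool1d(2, 2):               length (30 - 2) // 2 + 1 = 15
#   conv_2 (kernel 5, stride 1):   length 15 - 5 + 1 = 11, 32 channels
#   dense_dim = 11 * 32 = 352  ->  linear_1 = nn.Linear(352, 128)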
......
import os.path as osp
from tqdm import tqdm
from copy import deepcopy
import torch
import dgl
from torch.utils.data import DataLoader, Dataset
from dgl import DGLGraph, NID
from dgl.dataloading.negative_sampler import Uniform
from dgl import add_self_loop
from tqdm import tqdm
from utils import drnl_node_labeling
import dgl
from dgl import NID, DGLGraph, add_self_loop
from dgl.dataloading.negative_sampler import Uniform
class GraphDataSet(Dataset):
"""
......@@ -37,7 +38,9 @@ class PosNegEdgesGenerator(object):
shuffle(bool): whether to shuffle the generated graph list
"""
def __init__(self, g, split_edge, neg_samples=1, subsample_ratio=0.1, shuffle=True):
def __init__(
self, g, split_edge, neg_samples=1, subsample_ratio=0.1, shuffle=True
):
self.neg_sampler = Uniform(neg_samples)
self.subsample_ratio = subsample_ratio
self.split_edge = split_edge
......@@ -46,24 +49,29 @@ class PosNegEdgesGenerator(object):
def __call__(self, split_type):
if split_type == 'train':
if split_type == "train":
subsample_ratio = self.subsample_ratio
else:
subsample_ratio = 1
pos_edges = self.split_edge[split_type]['edge']
if split_type == 'train':
pos_edges = self.split_edge[split_type]["edge"]
if split_type == "train":
# Adding self loop in train avoids sampling the source node itself.
g = add_self_loop(self.g)
eids = g.edge_ids(pos_edges[:, 0], pos_edges[:, 1])
neg_edges = torch.stack(self.neg_sampler(g, eids), dim=1)
else:
neg_edges = self.split_edge[split_type]['edge_neg']
neg_edges = self.split_edge[split_type]["edge_neg"]
pos_edges = self.subsample(pos_edges, subsample_ratio).long()
neg_edges = self.subsample(neg_edges, subsample_ratio).long()
edges = torch.cat([pos_edges, neg_edges])
labels = torch.cat([torch.ones(pos_edges.size(0), 1), torch.zeros(neg_edges.size(0), 1)])
labels = torch.cat(
[
torch.ones(pos_edges.size(0), 1),
torch.zeros(neg_edges.size(0), 1),
]
)
if self.shuffle:
perm = torch.randperm(edges.size(0))
edges = edges[perm]
......@@ -84,7 +92,7 @@ class PosNegEdgesGenerator(object):
num_edges = edges.size(0)
perm = torch.randperm(num_edges)
perm = perm[:int(subsample_ratio * num_edges)]
perm = perm[: int(subsample_ratio * num_edges)]
edges = edges[perm]
return edges
......@@ -144,8 +152,16 @@ class SEALSampler(object):
subgraph = dgl.node_subgraph(self.graph, sample_nodes)
# Each node should have a unique node id in the new subgraph
u_id = int(torch.nonzero(subgraph.ndata[NID] == int(target_nodes[0]), as_tuple=False))
v_id = int(torch.nonzero(subgraph.ndata[NID] == int(target_nodes[1]), as_tuple=False))
u_id = int(
torch.nonzero(
subgraph.ndata[NID] == int(target_nodes[0]), as_tuple=False
)
)
v_id = int(
torch.nonzero(
subgraph.ndata[NID] == int(target_nodes[1]), as_tuple=False
)
)
# remove link between target nodes in positive subgraphs.
if subgraph.has_edges_between(u_id, v_id):
......@@ -156,7 +172,7 @@ class SEALSampler(object):
subgraph.remove_edges(link_id)
z = drnl_node_labeling(subgraph, u_id, v_id)
subgraph.ndata['z'] = z
subgraph.ndata["z"] = z
return subgraph
......@@ -171,10 +187,19 @@ class SEALSampler(object):
def __call__(self, edges, labels):
subgraph_list = []
labels_list = []
edge_dataset = EdgeDataSet(edges, labels, transform=self.sample_subgraph)
self.print_fn('Using {} workers in sampling job.'.format(self.num_workers))
sampler = DataLoader(edge_dataset, batch_size=32, num_workers=self.num_workers,
shuffle=False, collate_fn=self._collate)
edge_dataset = EdgeDataSet(
edges, labels, transform=self.sample_subgraph
)
self.print_fn(
"Using {} workers in sampling job.".format(self.num_workers)
)
sampler = DataLoader(
edge_dataset,
batch_size=32,
num_workers=self.num_workers,
shuffle=False,
collate_fn=self._collate,
)
for subgraph, label in tqdm(sampler, ncols=100):
label_copy = deepcopy(label)
subgraph = dgl.unbatch(subgraph)
......@@ -200,8 +225,20 @@ class SEALData(object):
use_coalesce(bool): True to coalesce the graph. A graph with multi-edges needs to be coalesced
"""
def __init__(self, g, split_edge, hop=1, neg_samples=1, subsample_ratio=1, prefix=None, save_dir=None,
num_workers=32, shuffle=True, use_coalesce=True, print_fn=print):
def __init__(
self,
g,
split_edge,
hop=1,
neg_samples=1,
subsample_ratio=1,
prefix=None,
save_dir=None,
num_workers=32,
shuffle=True,
use_coalesce=True,
print_fn=print,
):
self.g = g
self.hop = hop
self.subsample_ratio = subsample_ratio
......@@ -209,15 +246,19 @@ class SEALData(object):
self.save_dir = save_dir
self.print_fn = print_fn
self.generator = PosNegEdgesGenerator(g=self.g,
split_edge=split_edge,
neg_samples=neg_samples,
subsample_ratio=subsample_ratio,
shuffle=shuffle)
self.generator = PosNegEdgesGenerator(
g=self.g,
split_edge=split_edge,
neg_samples=neg_samples,
subsample_ratio=subsample_ratio,
shuffle=shuffle,
)
if use_coalesce:
for k, v in g.edata.items():
g.edata[k] = v.float()  # dgl.to_simple() requires float data
self.g = dgl.to_simple(g, copy_ndata=True, copy_edata=True, aggregator='sum')
self.g = dgl.to_simple(
g, copy_ndata=True, copy_edata=True, aggregator="sum"
)
self.ndata = {k: v for k, v in self.g.ndata.items()}
self.edata = {k: v for k, v in self.g.edata.items()}
......@@ -226,25 +267,28 @@ class SEALData(object):
self.print_fn("Save ndata and edata in class.")
self.print_fn("Clear ndata and edata in graph.")
self.sampler = SEALSampler(graph=self.g,
hop=hop,
num_workers=num_workers,
print_fn=print_fn)
self.sampler = SEALSampler(
graph=self.g, hop=hop, num_workers=num_workers, print_fn=print_fn
)
def __call__(self, split_type):
if split_type == 'train':
if split_type == "train":
subsample_ratio = self.subsample_ratio
else:
subsample_ratio = 1
path = osp.join(self.save_dir or '', '{}_{}_{}-hop_{}-subsample.bin'.format(self.prefix, split_type,
self.hop, subsample_ratio))
path = osp.join(
self.save_dir or "",
"{}_{}_{}-hop_{}-subsample.bin".format(
self.prefix, split_type, self.hop, subsample_ratio
),
)
if osp.exists(path):
self.print_fn("Load existing processed {} files".format(split_type))
graph_list, data = dgl.load_graphs(path)
dataset = GraphDataSet(graph_list, data['labels'])
dataset = GraphDataSet(graph_list, data["labels"])
else:
self.print_fn("Processed {} files not exist.".format(split_type))
......@@ -254,6 +298,6 @@ class SEALData(object):
graph_list, labels = self.sampler(edges, labels)
dataset = GraphDataSet(graph_list, labels)
dgl.save_graphs(path, graph_list, {'labels': labels})
dgl.save_graphs(path, graph_list, {"labels": labels})
self.print_fn("Save preprocessed subgraph to {}".format(path))
return dataset
import argparse
from scipy.sparse.csgraph import shortest_path
import numpy as np
import pandas as pd
import torch
import dgl
from ogb.linkproppred import DglLinkPropPredDataset, Evaluator
from scipy.sparse.csgraph import shortest_path
import dgl
def parse_arguments():
"""
Parse arguments
"""
parser = argparse.ArgumentParser(description='SEAL')
parser.add_argument('--dataset', type=str, default='ogbl-collab')
parser.add_argument('--gpu_id', type=int, default=0)
parser.add_argument('--hop', type=int, default=1)
parser.add_argument('--model', type=str, default='dgcnn')
parser.add_argument('--gcn_type', type=str, default='gcn')
parser.add_argument('--num_layers', type=int, default=3)
parser.add_argument('--hidden_units', type=int, default=32)
parser.add_argument('--sort_k', type=int, default=30)
parser.add_argument('--pooling', type=str, default='sum')
parser.add_argument('--dropout', type=float, default=0.5)
parser.add_argument('--hits_k', type=int, default=50)
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--neg_samples', type=int, default=1)
parser.add_argument('--subsample_ratio', type=float, default=0.1)
parser.add_argument('--epochs', type=int, default=60)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--eval_steps', type=int, default=5)
parser.add_argument('--num_workers', type=int, default=32)
parser.add_argument('--random_seed', type=int, default=2021)
parser.add_argument('--save_dir', type=str, default='./processed')
parser = argparse.ArgumentParser(description="SEAL")
parser.add_argument("--dataset", type=str, default="ogbl-collab")
parser.add_argument("--gpu_id", type=int, default=0)
parser.add_argument("--hop", type=int, default=1)
parser.add_argument("--model", type=str, default="dgcnn")
parser.add_argument("--gcn_type", type=str, default="gcn")
parser.add_argument("--num_layers", type=int, default=3)
parser.add_argument("--hidden_units", type=int, default=32)
parser.add_argument("--sort_k", type=int, default=30)
parser.add_argument("--pooling", type=str, default="sum")
parser.add_argument("--dropout", type=str, default=0.5)
parser.add_argument("--hits_k", type=int, default=50)
parser.add_argument("--lr", type=float, default=0.0001)
parser.add_argument("--neg_samples", type=int, default=1)
parser.add_argument("--subsample_ratio", type=float, default=0.1)
parser.add_argument("--epochs", type=int, default=60)
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--eval_steps", type=int, default=5)
parser.add_argument("--num_workers", type=int, default=32)
parser.add_argument("--random_seed", type=int, default=2021)
parser.add_argument("--save_dir", type=str, default="./processed")
args = parser.parse_args()
return args
......@@ -79,11 +81,15 @@ def drnl_node_labeling(subgraph, src, dst):
idx = list(range(dst)) + list(range(dst + 1, adj.shape[0]))
adj_wo_dst = adj[idx, :][:, idx]
dist2src = shortest_path(adj_wo_dst, directed=False, unweighted=True, indices=src)
dist2src = shortest_path(
adj_wo_dst, directed=False, unweighted=True, indices=src
)
dist2src = np.insert(dist2src, dst, 0, axis=0)
dist2src = torch.from_numpy(dist2src)
dist2dst = shortest_path(adj_wo_src, directed=False, unweighted=True, indices=dst - 1)
dist2dst = shortest_path(
adj_wo_src, directed=False, unweighted=True, indices=dst - 1
)
dist2dst = np.insert(dist2dst, src, 0, axis=0)
dist2dst = torch.from_numpy(dist2dst)
......@@ -92,9 +98,9 @@ def drnl_node_labeling(subgraph, src, dst):
z = 1 + torch.min(dist2src, dist2dst)
z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1)
z[src] = 1.
z[dst] = 1.
z[torch.isnan(z)] = 0.
z[src] = 1.0
z[dst] = 1.0
z[torch.isnan(z)] = 0.0
return z.to(torch.long)
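# A worked instance of the DRNL label above, assuming dist_over_2 and dist_mod_2
# are the integer quotient and remainder of d = dist2src + dist2dst by 2 (as in
# the SEAL paper): for a node with dist2src = 1 and dist2dst = 2, d = 3, so
#   z = 1 + min(1, 2) + 1 * (1 + 1 - 1) = 3
# The two target nodes are forced to z = 1, and unreachable nodes (NaN distances)
# fall back to z = 0, matching the assignments above.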
......@@ -115,9 +121,11 @@ def evaluate_hits(name, pos_pred, neg_pred, K):
"""
evaluator = Evaluator(name)
evaluator.K = K
hits = evaluator.eval({
'y_pred_pos': pos_pred,
'y_pred_neg': neg_pred,
})[f'hits@{K}']
hits = evaluator.eval(
{
"y_pred_pos": pos_pred,
"y_pred_neg": neg_pred,
}
)[f"hits@{K}"]
return hits
......@@ -5,37 +5,46 @@ Paper: https://arxiv.org/abs/1902.07153
Code: https://github.com/Tiiiger/SGC
SGC implementation in DGL.
"""
import argparse, time, math
import argparse
import math
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
import dgl
from dgl.data import register_data_args
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
import dgl.function as fn
from dgl.data import (
CiteseerGraphDataset,
CoraGraphDataset,
PubmedGraphDataset,
register_data_args,
)
from dgl.nn.pytorch.conv import SGConv
def evaluate(model, g, features, labels, mask):
model.eval()
with torch.no_grad():
logits = model(g, features)[mask] # only compute the evaluation set
logits = model(g, features)[mask] # only compute the evaluation set
labels = labels[mask]
_, indices = torch.max(logits, dim=1)
correct = torch.sum(indices == labels)
return correct.item() * 1.0 / len(labels)
def main(args):
# load and preprocess dataset
if args.dataset == 'cora':
if args.dataset == "cora":
data = CoraGraphDataset()
elif args.dataset == 'citeseer':
elif args.dataset == "citeseer":
data = CiteseerGraphDataset()
elif args.dataset == 'pubmed':
elif args.dataset == "pubmed":
data = PubmedGraphDataset()
else:
raise ValueError('Unknown dataset: {}'.format(args.dataset))
raise ValueError("Unknown dataset: {}".format(args.dataset))
g = data[0]
if args.gpu < 0:
......@@ -44,24 +53,29 @@ def main(args):
cuda = True
g = g.int().to(args.gpu)
features = g.ndata['feat']
labels = g.ndata['label']
train_mask = g.ndata['train_mask']
val_mask = g.ndata['val_mask']
test_mask = g.ndata['test_mask']
features = g.ndata["feat"]
labels = g.ndata["label"]
train_mask = g.ndata["train_mask"]
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = g.number_of_edges()
print("""----Data statistics------'
print(
"""----Data statistics------'
#Edges %d
#Classes %d
#Train samples %d
#Val samples %d
#Test samples %d""" %
(n_edges, n_classes,
train_mask.int().sum().item(),
val_mask.int().sum().item(),
test_mask.int().sum().item()))
#Test samples %d"""
% (
n_edges,
n_classes,
train_mask.int().sum().item(),
val_mask.int().sum().item(),
test_mask.int().sum().item(),
)
)
n_edges = g.number_of_edges()
# add self loop
......@@ -69,20 +83,16 @@ def main(args):
g = dgl.add_self_loop(g)
# create SGC model
model = SGConv(in_feats,
n_classes,
k=2,
cached=True,
bias=args.bias)
model = SGConv(in_feats, n_classes, k=2, cached=True, bias=args.bias)
if cuda:
model.cuda()
loss_fcn = torch.nn.CrossEntropyLoss()
# use optimizer
optimizer = torch.optim.Adam(model.parameters(),
lr=args.lr,
weight_decay=args.weight_decay)
optimizer = torch.optim.Adam(
model.parameters(), lr=args.lr, weight_decay=args.weight_decay
)
# initialize graph
dur = []
......@@ -91,7 +101,7 @@ def main(args):
if epoch >= 3:
t0 = time.time()
# forward
logits = model(g, features) # only compute the train set
logits = model(g, features) # only compute the train set
loss = loss_fcn(logits[train_mask], labels[train_mask])
optimizer.zero_grad()
......@@ -102,28 +112,36 @@ def main(args):
dur.append(time.time() - t0)
acc = evaluate(model, g, features, labels, val_mask)
print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
"ETputs(KTEPS) {:.2f}". format(epoch, np.mean(dur), loss.item(),
acc, n_edges / np.mean(dur) / 1000))
print(
"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
"ETputs(KTEPS) {:.2f}".format(
epoch,
np.mean(dur),
loss.item(),
acc,
n_edges / np.mean(dur) / 1000,
)
)
print()
acc = evaluate(model, g, features, labels, test_mask)
print("Test Accuracy {:.4f}".format(acc))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='SGC')
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="SGC")
register_data_args(parser)
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
parser.add_argument("--lr", type=float, default=0.2,
help="learning rate")
parser.add_argument("--bias", action='store_true', default=False,
help="flag to use bias")
parser.add_argument("--n-epochs", type=int, default=100,
help="number of training epochs")
parser.add_argument("--weight-decay", type=float, default=5e-6,
help="Weight for L2 loss")
parser.add_argument("--gpu", type=int, default=-1, help="gpu")
parser.add_argument("--lr", type=float, default=0.2, help="learning rate")
parser.add_argument(
"--bias", action="store_true", default=False, help="flag to use bias"
)
parser.add_argument(
"--n-epochs", type=int, default=100, help="number of training epochs"
)
parser.add_argument(
"--weight-decay", type=float, default=5e-6, help="Weight for L2 loss"
)
args = parser.parse_args()
print(args)
......