"vscode:/vscode.git/clone" did not exist on "3a140cc0633bfaae0845ce7f7ca0efd8ab851591"
Unverified commit a9f2acf3, authored by Hongzhi (Steve) Chen and committed by GitHub

[Misc] Black auto fix. (#4641)



* [Misc] Black auto fix.

* sort
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 08c50eb7
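The commit is a pure formatting pass: Black normalizes quoting, wrapping, and indentation, and the "sort" follow-up applies isort-style import grouping. A minimal sketch of how such a pass can be reproduced programmatically, assuming `black` and `isort` are installed (the repo's exact line-length configuration is not shown on this page):

import black
import isort

src = "import argparse, time\nimport numpy as np\nx = {'a':1}\n"

# isort splits combined imports and groups third-party before local modules,
# which is where lines like "import argparse, time" become two imports.
tidied = isort.code(src)

# Black then normalizes quotes, re-indents, and wraps long call sites.
print(black.format_str(tidied, mode=black.Mode()))

The file-by-file diffs below show the resulting post-format code; `@@ ... @@` hunk headers mark collapsed, unchanged regions.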
import argparse
import time

import mxnet as mx
import networkx as nx
import numpy as np
from mxnet import gluon

import dgl
from dgl.data import (CiteseerGraphDataset, CoraGraphDataset,
                      PubmedGraphDataset, register_data_args)

from tagcn import TAGCN


def evaluate(model, features, labels, mask):
    pred = model(features).argmax(axis=1)
    accuracy = ((pred == labels) * mask).sum() / mask.sum().asscalar()
    return accuracy.asscalar()
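The masked-accuracy idiom in evaluate() is compact; a plain NumPy stand-in (illustrative only, not part of the commit) shows how the 0/1 mask restricts the average to the chosen split:

import numpy as np

pred = np.array([0, 1, 2, 1])
labels = np.array([0, 2, 2, 1])
mask = np.array([1, 1, 0, 0])  # only the first two nodes are in this split

print(((pred == labels) * mask).sum() / mask.sum())  # 0.5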
def main(args):
    # load and preprocess dataset
    if args.dataset == "cora":
        data = CoraGraphDataset()
    elif args.dataset == "citeseer":
        data = CiteseerGraphDataset()
    elif args.dataset == "pubmed":
        data = PubmedGraphDataset()
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))

    g = data[0]
    if args.gpu < 0:
@@ -35,37 +38,44 @@ def main(args):
        ctx = mx.gpu(args.gpu)
        g = g.to(ctx)

    features = g.ndata["feat"]
    labels = mx.nd.array(g.ndata["label"], dtype="float32", ctx=ctx)
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = data.graph.number_of_edges()
    print(
        """----Data statistics------'
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d"""
        % (
            n_edges,
            n_classes,
            train_mask.sum().asscalar(),
            val_mask.sum().asscalar(),
            test_mask.sum().asscalar(),
        )
    )

    # add self loop
    g = dgl.remove_self_loop(g)
    g = dgl.add_self_loop(g)

    # create TAGCN model
    model = TAGCN(
        g,
        in_feats,
        args.n_hidden,
        n_classes,
        args.n_layers,
        mx.nd.relu,
        args.dropout,
    )

    model.initialize(ctx=ctx)
    n_train_samples = train_mask.sum().asscalar()
@@ -73,8 +83,11 @@ def main(args):
    # use optimizer
    print(model.collect_params())
    trainer = gluon.Trainer(
        model.collect_params(),
        "adam",
        {"learning_rate": args.lr, "wd": args.weight_decay},
    )

    # initialize graph
    dur = []
@@ -94,33 +107,47 @@ def main(args):
        loss.asscalar()
        dur.append(time.time() - t0)
        acc = evaluate(model, features, labels, val_mask)
        print(
            "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
            "ETputs(KTEPS) {:.2f}".format(
                epoch,
                np.mean(dur),
                loss.asscalar(),
                acc,
                n_edges / np.mean(dur) / 1000,
            )
        )

    print()
    acc = evaluate(model, features, labels, val_mask)
    print("Test accuracy {:.2%}".format(acc))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="TAGCN")
    register_data_args(parser)
    parser.add_argument(
        "--dropout", type=float, default=0.5, help="dropout probability"
    )
    parser.add_argument("--gpu", type=int, default=-1, help="gpu")
    parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
    parser.add_argument(
        "--n-epochs", type=int, default=200, help="number of training epochs"
    )
    parser.add_argument(
        "--n-hidden", type=int, default=16, help="number of hidden tagcn units"
    )
    parser.add_argument(
        "--n-layers", type=int, default=1, help="number of hidden tagcn layers"
    )
    parser.add_argument(
        "--weight-decay", type=float, default=5e-4, help="Weight for L2 loss"
    )
    parser.add_argument(
        "--self-loop",
        action="store_true",
        help="graph self-loop (default=False)",
    )
    parser.set_defaults(self_loop=False)

    args = parser.parse_args()
    print(args)
...
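That ends the TAGCN training script (its final lines are collapsed in this view). For orientation, a TAGCN layer aggregates polynomial powers of the normalized adjacency; a dense NumPy sketch of that filter, with invented shapes and no claim to match tagcn.py exactly:

import numpy as np

def tagcn_filter(a_hat, x, weights):
    # weights[k] multiplies the k-hop term A_hat^k @ X; k = 0 is the identity.
    return sum(
        np.linalg.matrix_power(a_hat, k) @ x @ w for k, w in enumerate(weights)
    )

The next file is the MXNet Tree-LSTM sentiment example's training script.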
import argparse
import collections
import os
import time
import warnings
import zipfile

os.environ["DGLBACKEND"] = "mxnet"
os.environ["MXNET_GPU_MEM_POOL_TYPE"] = "Round"

import mxnet as mx
import numpy as np
from mxnet import gluon
from tree_lstm import TreeLSTM

import dgl
import dgl.data as data

SSTBatch = collections.namedtuple(
    "SSTBatch", ["graph", "mask", "wordid", "label"]
)


def batcher(ctx):
    def batcher_dev(batch):
        batch_trees = dgl.batch(batch)
        return SSTBatch(
            graph=batch_trees,
            mask=batch_trees.ndata["mask"].as_in_context(ctx),
            wordid=batch_trees.ndata["x"].as_in_context(ctx),
            label=batch_trees.ndata["y"].as_in_context(ctx),
        )

    return batcher_dev


def prepare_glove():
    if not (
        os.path.exists("glove.840B.300d.txt")
        and data.utils.check_sha1(
            "glove.840B.300d.txt",
            sha1_hash="294b9f37fa64cce31f9ebb409c266fc379527708",
        )
    ):
        zip_path = data.utils.download(
            "http://nlp.stanford.edu/data/glove.840B.300d.zip",
            sha1_hash="8084fbacc2dee3b1fd1ca4cc534cbfff3519ed0d",
        )
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall()
        if not data.utils.check_sha1(
            "glove.840B.300d.txt",
            sha1_hash="294b9f37fa64cce31f9ebb409c266fc379527708",
        ):
            warnings.warn(
                "The downloaded glove embedding file checksum mismatch. File content "
                "may be corrupted."
            )
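prepare_glove() leans on dgl.data.utils for download-and-verify; the same checksum pattern in standard-library form (an assumed helper for illustration, not DGL's implementation):

import hashlib

def sha1_matches(path, expected):
    h = hashlib.sha1()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
            h.update(chunk)
    return h.hexdigest() == expected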
def main(args):
    np.random.seed(args.seed)
@@ -53,7 +70,11 @@ def main(args):
        if args.gpu in mx.test_utils.list_gpus():
            ctx = mx.gpu(args.gpu)
        else:
            print(
                "Requested GPU id {} was not found. Defaulting to CPU implementation".format(
                    args.gpu
                )
            )
            ctx = mx.cpu()
    else:
        ctx = mx.cpu()
@@ -62,45 +83,63 @@ def main(args):
    prepare_glove()

    trainset = data.SSTDataset()
    train_loader = gluon.data.DataLoader(
        dataset=trainset,
        batch_size=args.batch_size,
        batchify_fn=batcher(ctx),
        shuffle=True,
        num_workers=0,
    )
    devset = data.SSTDataset(mode="dev")
    dev_loader = gluon.data.DataLoader(
        dataset=devset,
        batch_size=100,
        batchify_fn=batcher(ctx),
        shuffle=True,
        num_workers=0,
    )

    testset = data.SSTDataset(mode="test")
    test_loader = gluon.data.DataLoader(
        dataset=testset,
        batch_size=100,
        batchify_fn=batcher(ctx),
        shuffle=False,
        num_workers=0,
    )

    model = TreeLSTM(
        trainset.vocab_size,
        args.x_size,
        args.h_size,
        trainset.num_classes,
        args.dropout,
        cell_type="childsum" if args.child_sum else "nary",
        pretrained_emb=trainset.pretrained_emb,
        ctx=ctx,
    )
    print(model)
    params_ex_emb = [
        x
        for x in model.collect_params().values()
        if x.grad_req != "null" and x.shape[0] != trainset.vocab_size
    ]
    params_emb = list(model.embedding.collect_params().values())
    for p in params_emb:
        p.lr_mult = 0.1

    model.initialize(mx.init.Xavier(magnitude=1), ctx=ctx)
    model.hybridize()
    trainer = gluon.Trainer(
        model.collect_params("^(?!embedding).*$"),
        "adagrad",
        {"learning_rate": args.lr, "wd": args.weight_decay},
    )
    trainer_emb = gluon.Trainer(
        model.collect_params("^embedding.*$"),
        "adagrad",
        {"learning_rate": args.lr},
    )

    dur = []
    L = gluon.loss.SoftmaxCrossEntropyLoss(axis=1)
@@ -114,7 +153,7 @@ def main(args):
            h = mx.nd.zeros((n, args.h_size), ctx=ctx)
            c = mx.nd.zeros((n, args.h_size), ctx=ctx)
            if step >= 3:
                t0 = time.time()  # tik

            with mx.autograd.record():
                pred = model(batch, h, c)
                loss = L(pred, batch.label)
@@ -124,17 +163,35 @@ def main(args):
            trainer_emb.step(args.batch_size)
            if step >= 3:
                dur.append(time.time() - t0)  # tok

            if step > 0 and step % args.log_every == 0:
                pred = pred.argmax(axis=1).astype(batch.label.dtype)
                acc = (batch.label == pred).sum()
                root_ids = [
                    i
                    for i in range(batch.graph.number_of_nodes())
                    if batch.graph.out_degree(i) == 0
                ]
                root_acc = np.sum(
                    batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids]
                )

                print(
                    "Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} | Root Acc {:.4f} | Time(s) {:.4f}".format(
                        epoch,
                        step,
                        loss.sum().asscalar(),
                        1.0 * acc.asscalar() / len(batch.label),
                        1.0 * root_acc / len(root_ids),
                        np.mean(dur),
                    )
                )
        print(
            "Epoch {:05d} training time {:.4f}s".format(
                epoch, time.time() - t_epoch
            )
        )

        # eval on dev set
        accs = []
@@ -148,31 +205,48 @@ def main(args):
            acc = (batch.label == pred).sum().asscalar()
            accs.append([acc, len(batch.label)])
            root_ids = [
                i
                for i in range(batch.graph.number_of_nodes())
                if batch.graph.out_degree(i) == 0
            ]
            root_acc = np.sum(
                batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids]
            )
            root_accs.append([root_acc, len(root_ids)])

        dev_acc = (
            1.0 * np.sum([x[0] for x in accs]) / np.sum([x[1] for x in accs])
        )
        dev_root_acc = (
            1.0
            * np.sum([x[0] for x in root_accs])
            / np.sum([x[1] for x in root_accs])
        )
        print(
            "Epoch {:05d} | Dev Acc {:.4f} | Root Acc {:.4f}".format(
                epoch, dev_acc, dev_root_acc
            )
        )

        if dev_root_acc > best_dev_acc:
            best_dev_acc = dev_root_acc
            best_epoch = epoch
            model.save_parameters("best_{}.params".format(args.seed))
        else:
            if best_epoch <= epoch - 10:
                break

        # lr decay
        trainer.set_learning_rate(max(1e-5, trainer.learning_rate * 0.99))
        print(trainer.learning_rate)
        trainer_emb.set_learning_rate(
            max(1e-5, trainer_emb.learning_rate * 0.99)
        )
        print(trainer_emb.learning_rate)

    # test
    model.load_parameters("best_{}.params".format(args.seed))
    accs = []
    root_accs = []
    for step, batch in enumerate(test_loader):
@@ -184,30 +258,46 @@ def main(args):
        acc = (batch.label == pred).sum().asscalar()
        accs.append([acc, len(batch.label)])
        root_ids = [
            i
            for i in range(batch.graph.number_of_nodes())
            if batch.graph.out_degree(i) == 0
        ]
        root_acc = np.sum(
            batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids]
        )
        root_accs.append([root_acc, len(root_ids)])

    test_acc = 1.0 * np.sum([x[0] for x in accs]) / np.sum([x[1] for x in accs])
    test_root_acc = (
        1.0
        * np.sum([x[0] for x in root_accs])
        / np.sum([x[1] for x in root_accs])
    )
    print(
        "------------------------------------------------------------------------------------"
    )
    print(
        "Epoch {:05d} | Test Acc {:.4f} | Root Acc {:.4f}".format(
            best_epoch, test_acc, test_root_acc
        )
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--seed", type=int, default=41)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--child-sum", action="store_true")
    parser.add_argument("--x-size", type=int, default=300)
    parser.add_argument("--h-size", type=int, default=150)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--log-every", type=int, default=5)
    parser.add_argument("--lr", type=float, default=0.05)
    parser.add_argument("--weight-decay", type=float, default=1e-4)
    parser.add_argument("--dropout", type=float, default=0.5)
    parser.add_argument("--use-glove", action="store_true")
    args = parser.parse_args()
    print(args)
    main(args)
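A detail worth noting in the script above: roots are selected with out_degree == 0 because the SST trees point edges from child to parent, so only the root emits no edge. A toy check (illustrative only, using DGL's default PyTorch backend):

import dgl
import torch

g = dgl.graph((torch.tensor([0, 1]), torch.tensor([2, 2])))  # two leaves, one root
print(torch.nonzero(g.out_degrees() == 0).squeeze())  # tensor(2)

The Tree-LSTM model definitions in tree_lstm.py follow.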
@@ -2,14 +2,17 @@
Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks
https://arxiv.org/abs/1503.00075
"""
import itertools
import time

import mxnet as mx
import networkx as nx
import numpy as np
from mxnet import gluon

import dgl


class _TreeLSTMCellNodeFunc(gluon.HybridBlock):
    def hybrid_forward(self, F, iou, b_iou, c):
        iou = F.broadcast_add(iou, b_iou)
@@ -20,6 +23,7 @@ class _TreeLSTMCellNodeFunc(gluon.HybridBlock):

        return h, c


class _TreeLSTMCellReduceFunc(gluon.HybridBlock):
    def __init__(self, U_iou, U_f):
        super(_TreeLSTMCellReduceFunc, self).__init__()
@@ -33,34 +37,39 @@ class _TreeLSTMCellReduceFunc(gluon.HybridBlock):
        iou = self.U_iou(h_cat)
        return iou, c


class _TreeLSTMCell(gluon.HybridBlock):
    def __init__(self, h_size):
        super(_TreeLSTMCell, self).__init__()
        self._apply_node_func = _TreeLSTMCellNodeFunc()
        self.b_iou = self.params.get(
            "bias", shape=(1, 3 * h_size), init="zeros"
        )

    def message_func(self, edges):
        return {"h": edges.src["h"], "c": edges.src["c"]}

    def apply_node_func(self, nodes):
        iou = nodes.data["iou"]
        b_iou, c = self.b_iou.data(iou.context), nodes.data["c"]
        h, c = self._apply_node_func(iou, b_iou, c)
        return {"h": h, "c": c}


class TreeLSTMCell(_TreeLSTMCell):
    def __init__(self, x_size, h_size):
        super(TreeLSTMCell, self).__init__(h_size)
        self._reduce_func = _TreeLSTMCellReduceFunc(
            gluon.nn.Dense(3 * h_size, use_bias=False),
            gluon.nn.Dense(2 * h_size),
        )
        self.W_iou = gluon.nn.Dense(3 * h_size, use_bias=False)

    def reduce_func(self, nodes):
        h, c = nodes.mailbox["h"], nodes.mailbox["c"]
        iou, c = self._reduce_func(h, c)
        return {"iou": iou, "c": c}


class ChildSumTreeLSTMCell(_TreeLSTMCell):
    def __init__(self, x_size, h_size):
@@ -70,31 +79,34 @@ class ChildSumTreeLSTMCell(_TreeLSTMCell):
        self.U_f = gluon.nn.Dense(h_size)

    def reduce_func(self, nodes):
        h_tild = nodes.mailbox["h"].sum(axis=1)
        f = self.U_f(nodes.mailbox["h"]).sigmoid()
        c = (f * nodes.mailbox["c"]).sum(axis=1)
        return {"iou": self.U_iou(h_tild), "c": c}
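ChildSumTreeLSTMCell.reduce_func above sums child hidden states and gates each child's cell state; a NumPy shape sketch (stand-in arrays, not DGL mailboxes):

import numpy as np

h = np.random.rand(4, 3, 8)   # (num_parents, num_children, h_size) mailbox
c = np.random.rand(4, 3, 8)
f = 1.0 / (1.0 + np.exp(-h))  # stand-in for U_f(h).sigmoid()

h_tild = h.sum(axis=1)        # (4, 8): summed child hidden states
c_new = (f * c).sum(axis=1)   # (4, 8): forget-gated sum of child cells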
class TreeLSTM(gluon.nn.Block):
    def __init__(
        self,
        num_vocabs,
        x_size,
        h_size,
        num_classes,
        dropout,
        cell_type="nary",
        pretrained_emb=None,
        ctx=None,
    ):
        super(TreeLSTM, self).__init__()
        self.x_size = x_size
        self.embedding = gluon.nn.Embedding(num_vocabs, x_size)
        if pretrained_emb is not None:
            print("Using glove")
            self.embedding.initialize(ctx=ctx)
            self.embedding.weight.set_data(pretrained_emb)
        self.dropout = gluon.nn.Dropout(dropout)
        self.linear = gluon.nn.Dense(num_classes)
        cell = TreeLSTMCell if cell_type == "nary" else ChildSumTreeLSTMCell
        self.cell = cell(x_size, h_size)
        self.ctx = ctx
@@ -118,15 +130,17 @@ class TreeLSTM(gluon.nn.Block):
        # feed embedding
        embeds = self.embedding(batch.wordid * batch.mask)
        wiou = self.cell.W_iou(self.dropout(embeds))
        g.ndata["iou"] = wiou * batch.mask.expand_dims(-1).astype(wiou.dtype)
        g.ndata["h"] = h
        g.ndata["c"] = c
        # propagate
        dgl.prop_nodes_topo(
            g,
            message_func=self.cell.message_func,
            reduce_func=self.cell.reduce_func,
            apply_node_func=self.cell.apply_node_func,
        )
        # compute logits
        h = self.dropout(g.ndata.pop("h"))
        logits = self.linear(h)
        return logits
@@ -6,20 +6,23 @@ Paper: https://arxiv.org/abs/1810.05997
Author's code: https://github.com/klicperajo/ppnp
"""
import torch.nn as nn

from dgl.nn.pytorch.conv import APPNPConv


class APPNP(nn.Module):
    def __init__(
        self,
        g,
        in_feats,
        hiddens,
        n_classes,
        activation,
        feat_drop,
        edge_drop,
        alpha,
        k,
    ):
        super(APPNP, self).__init__()
        self.g = g
        self.layers = nn.ModuleList()
...
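The APPNPConv this model wraps performs k steps of personalized-PageRank propagation seeded by the MLP output H, i.e. Z <- (1 - alpha) * A_hat @ Z + alpha * H, where A_hat is the symmetrically normalized adjacency with self-loops. A dense NumPy rendering for intuition (illustrative only, not the DGL kernel):

import numpy as np

def appnp_propagate(a_hat, h, alpha=0.1, k=10):
    z = h.copy()
    for _ in range(k):
        z = (1 - alpha) * (a_hat @ z) + alpha * h  # teleport back to H
    return z

The APPNP training script follows.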
import argparse
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from appnp import APPNP

import dgl
from dgl.data import (CiteseerGraphDataset, CoraGraphDataset,
                      PubmedGraphDataset, register_data_args)


def evaluate(model, features, labels, mask):
    model.eval()
@@ -22,14 +25,14 @@ def evaluate(model, features, labels, mask):


def main(args):
    # load and preprocess dataset
    if args.dataset == "cora":
        data = CoraGraphDataset()
    elif args.dataset == "citeseer":
        data = CiteseerGraphDataset()
    elif args.dataset == "pubmed":
        data = PubmedGraphDataset()
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))

    g = data[0]
    if args.gpu < 0:
@@ -38,24 +41,29 @@ def main(args):
        cuda = True
        g = g.to(args.gpu)

    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = g.number_of_edges()
    print(
        """----Data statistics------'
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d"""
        % (
            n_edges,
            n_classes,
            train_mask.int().sum().item(),
            val_mask.int().sum().item(),
            test_mask.int().sum().item(),
        )
    )

    n_edges = g.number_of_edges()
    # add self loop
@@ -63,24 +71,26 @@ def main(args):
    g = dgl.add_self_loop(g)

    # create APPNP model
    model = APPNP(
        g,
        in_feats,
        args.hidden_sizes,
        n_classes,
        F.relu,
        args.in_drop,
        args.edge_drop,
        args.alpha,
        args.k,
    )

    if cuda:
        model.cuda()
    loss_fcn = torch.nn.CrossEntropyLoss()

    # use optimizer
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    # initialize graph
    dur = []
@@ -100,36 +110,52 @@ def main(args):
        dur.append(time.time() - t0)

        acc = evaluate(model, features, labels, val_mask)
        print(
            "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
            "ETputs(KTEPS) {:.2f}".format(
                epoch,
                np.mean(dur),
                loss.item(),
                acc,
                n_edges / np.mean(dur) / 1000,
            )
        )

    print()
    acc = evaluate(model, features, labels, test_mask)
    print("Test Accuracy {:.4f}".format(acc))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="APPNP")
    register_data_args(parser)
    parser.add_argument(
        "--in-drop", type=float, default=0.5, help="input feature dropout"
    )
    parser.add_argument(
        "--edge-drop", type=float, default=0.5, help="edge propagation dropout"
    )
    parser.add_argument("--gpu", type=int, default=-1, help="gpu")
    parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
    parser.add_argument(
        "--n-epochs", type=int, default=200, help="number of training epochs"
    )
    parser.add_argument(
        "--hidden_sizes",
        type=int,
        nargs="+",
        default=[64],
        help="hidden unit sizes for appnp",
    )
    parser.add_argument(
        "--k", type=int, default=10, help="Number of propagation steps"
    )
    parser.add_argument(
        "--alpha", type=float, default=0.1, help="Teleport Probability"
    )
    parser.add_argument(
        "--weight-decay", type=float, default=5e-4, help="Weight for L2 loss"
    )

    args = parser.parse_args()
    print(args)
...
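The body of evaluate() is collapsed in this file's view; a typical masked evaluation for this kind of DGL/PyTorch script (an assumed sketch, not the file's exact code) looks like:

import torch

def evaluate(model, features, labels, mask):
    model.eval()
    with torch.no_grad():
        logits = model(features)
        pred = logits[mask].argmax(dim=1)
        correct = (pred == labels[mask]).sum().item()
        return correct / mask.int().sum().item()

The last two files in the commit are the ARMA example's training script and model.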
@@ -2,46 +2,52 @@
import argparse
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from model import ARMA4NC
from tqdm import trange

from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset


def main(args):
    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
    # Load from DGL dataset
    if args.dataset == "Cora":
        dataset = CoraGraphDataset()
    elif args.dataset == "Citeseer":
        dataset = CiteseerGraphDataset()
    elif args.dataset == "Pubmed":
        dataset = PubmedGraphDataset()
    else:
        raise ValueError("Dataset {} is invalid.".format(args.dataset))

    graph = dataset[0]

    # check cuda
    device = (
        f"cuda:{args.gpu}"
        if args.gpu >= 0 and torch.cuda.is_available()
        else "cpu"
    )

    # retrieve the number of classes
    n_classes = dataset.num_classes

    # retrieve labels of ground truth
    labels = graph.ndata.pop("label").to(device).long()

    # Extract node features
    feats = graph.ndata.pop("feat").to(device)
    n_features = feats.shape[-1]

    # retrieve masks for train/validation/test
    train_mask = graph.ndata.pop("train_mask")
    val_mask = graph.ndata.pop("val_mask")
    test_mask = graph.ndata.pop("test_mask")

    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze().to(device)
    val_idx = torch.nonzero(val_mask, as_tuple=False).squeeze().to(device)
@@ -50,14 +56,16 @@ def main(args):
    graph = graph.to(device)

    # Step 2: Create model =================================================================== #
    model = ARMA4NC(
        in_dim=n_features,
        hid_dim=args.hid_dim,
        out_dim=n_classes,
        num_stacks=args.num_stacks,
        num_layers=args.num_layers,
        activation=nn.ReLU(),
        dropout=args.dropout,
    ).to(device)

    best_model = copy.deepcopy(model)

    # Step 3: Create training components ===================================================== #
@@ -67,7 +75,7 @@ def main(args):

    # Step 4: training epoches =============================================================== #
    acc = 0
    no_improvement = 0
    epochs = trange(args.epochs, desc="Accuracy & Loss")

    for _ in epochs:
        # Training using a full graph
@@ -77,7 +85,9 @@ def main(args):

        # compute loss
        train_loss = loss_fn(logits[train_idx], labels[train_idx])
        train_acc = torch.sum(
            logits[train_idx].argmax(dim=1) == labels[train_idx]
        ).item() / len(train_idx)

        # backward
        opt.zero_grad()
@@ -89,16 +99,21 @@ def main(args):
        with torch.no_grad():
            valid_loss = loss_fn(logits[val_idx], labels[val_idx])
            valid_acc = torch.sum(
                logits[val_idx].argmax(dim=1) == labels[val_idx]
            ).item() / len(val_idx)

        # Print out performance
        epochs.set_description(
            "Train Acc {:.4f} | Train Loss {:.4f} | Val Acc {:.4f} | Val loss {:.4f}".format(
                train_acc, train_loss.item(), valid_acc, valid_loss.item()
            )
        )

        if valid_acc < acc:
            no_improvement += 1
            if no_improvement == args.early_stopping:
                print("Early stop.")
                break
        else:
            no_improvement = 0
@@ -107,31 +122,56 @@ def main(args):
    best_model.eval()
    logits = best_model(graph, feats)
    test_acc = torch.sum(
        logits[test_idx].argmax(dim=1) == labels[test_idx]
    ).item() / len(test_idx)

    print("Test Acc {:.4f}".format(test_acc))
    return test_acc
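The patience logic above counts epochs whose validation accuracy fails to beat the best seen so far; a toy trace of the same counter (values invented for illustration):

best, stale, patience = 0.0, 0, 3
for val_acc in [0.70, 0.71, 0.70, 0.70, 0.70]:
    if val_acc < best:
        stale += 1
        if stale == patience:  # three stale epochs in a row
            print("Early stop.")
            break
    else:
        stale = 0
        best = val_acc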
if __name__ == "__main__":
    """
    ARMA Model Hyperparameters
    """
    parser = argparse.ArgumentParser(description="ARMA GCN")

    # data source params
    parser.add_argument(
        "--dataset", type=str, default="Cora", help="Name of dataset."
    )
    # cuda params
    parser.add_argument(
        "--gpu", type=int, default=-1, help="GPU index. Default: -1, using CPU."
    )
    # training params
    parser.add_argument(
        "--epochs", type=int, default=2000, help="Training epochs."
    )
    parser.add_argument(
        "--early-stopping",
        type=int,
        default=100,
        help="Patient epochs to wait before early stopping.",
    )
    parser.add_argument("--lr", type=float, default=0.01, help="Learning rate.")
    parser.add_argument("--lamb", type=float, default=5e-4, help="L2 reg.")
    # model params
    parser.add_argument(
        "--hid-dim", type=int, default=16, help="Hidden layer dimensionalities."
    )
    parser.add_argument(
        "--num-stacks", type=int, default=2, help="Number of K."
    )
    parser.add_argument(
        "--num-layers", type=int, default=1, help="Number of T."
    )
    parser.add_argument(
        "--dropout",
        type=float,
        default=0.75,
        help="Dropout applied at all layers.",
    )

    args = parser.parse_args()
    print(args)
@@ -143,6 +183,6 @@ if __name__ == "__main__":
    mean = np.around(np.mean(acc_lists, axis=0), decimals=3)
    std = np.around(np.std(acc_lists, axis=0), decimals=3)

    print("Total acc: ", acc_lists)
    print("mean", mean)
    print("std", std)
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl.function as fn


def glorot(tensor):
    if tensor is not None:
        stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
        tensor.data.uniform_(-stdv, stdv)


def zeros(tensor):
    if tensor is not None:
        tensor.data.fill_(0)


class ARMAConv(nn.Module):
    def __init__(
        self,
        in_dim,
        out_dim,
        num_stacks,
        num_layers,
        activation=None,
        dropout=0.0,
        bias=True,
    ):
        super(ARMAConv, self).__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.K = num_stacks
@@ -32,23 +39,34 @@ class ARMAConv(nn.Module):
        self.dropout = nn.Dropout(p=dropout)

        # init weight
        self.w_0 = nn.ModuleDict(
            {
                str(k): nn.Linear(in_dim, out_dim, bias=False)
                for k in range(self.K)
            }
        )
        # deeper weight
        self.w = nn.ModuleDict(
            {
                str(k): nn.Linear(out_dim, out_dim, bias=False)
                for k in range(self.K)
            }
        )
        # v
        self.v = nn.ModuleDict(
            {
                str(k): nn.Linear(in_dim, out_dim, bias=False)
                for k in range(self.K)
            }
        )
        # bias
        if bias:
            self.bias = nn.Parameter(
                torch.Tensor(self.K, self.T, 1, self.out_dim)
            )
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self):
@@ -64,59 +82,66 @@ class ARMAConv(nn.Module):
        # assume that the graphs are undirected and graph.in_degrees() is the same as graph.out_degrees()
        degs = g.in_degrees().float().clamp(min=1)
        norm = torch.pow(degs, -0.5).to(feats.device).unsqueeze(1)

        output = []

        for k in range(self.K):
            feats = init_feats
            for t in range(self.T):
                feats = feats * norm
                g.ndata["h"] = feats
                g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
                feats = g.ndata.pop("h")
                feats = feats * norm

                if t == 0:
                    feats = self.w_0[str(k)](feats)
                else:
                    feats = self.w[str(k)](feats)

                feats += self.dropout(self.v[str(k)](init_feats))
                feats += self.v[str(k)](self.dropout(init_feats))

                if self.bias is not None:
                    feats += self.bias[k][t]

                if self.activation is not None:
                    feats = self.activation(feats)

            output.append(feats)

        return torch.stack(output).mean(dim=0)
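In forward() above, the two feats * norm rescalings around the neighbor sum implement symmetric normalization. A dense NumPy check on an invented 3-node graph (for intuition, not the sparse DGL kernel):

import numpy as np

A = np.array([[0, 1, 1], [1, 0, 0], [1, 0, 0]], dtype=float)
X = np.random.rand(3, 4)

norm = A.sum(axis=1).clip(min=1) ** -0.5
out = norm[:, None] * (A @ (norm[:, None] * X))  # == D^-1/2 A D^-1/2 @ X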
class ARMA4NC(nn.Module):
    def __init__(
        self,
        in_dim,
        hid_dim,
        out_dim,
        num_stacks,
        num_layers,
        activation=None,
        dropout=0.0,
    ):
        super(ARMA4NC, self).__init__()

        self.conv1 = ARMAConv(
            in_dim=in_dim,
            out_dim=hid_dim,
            num_stacks=num_stacks,
            num_layers=num_layers,
            activation=activation,
            dropout=dropout,
        )

        self.conv2 = ARMAConv(
            in_dim=hid_dim,
            out_dim=out_dim,
            num_stacks=num_stacks,
            num_layers=num_layers,
            activation=activation,
            dropout=dropout,
        )

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, g, feats):
...