Unverified Commit be8763fa authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] Black auto fix. (#4679)


Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent eae6ce2a
import numpy as np
import torch as th import torch as th
import torch.nn as nn import torch.nn as nn
import numpy as np
class PositionalEncoding(nn.Module): class PositionalEncoding(nn.Module):
"Position Encoding module" "Position Encoding module"
def __init__(self, dim_model, dropout, max_len=5000): def __init__(self, dim_model, dropout, max_len=5000):
super(PositionalEncoding, self).__init__() super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout) self.dropout = nn.Dropout(p=dropout)
# Compute the positional encodings once in log space. # Compute the positional encodings once in log space.
pe = th.zeros(max_len, dim_model, dtype=th.float) pe = th.zeros(max_len, dim_model, dtype=th.float)
position = th.arange(0, max_len, dtype=th.float).unsqueeze(1) position = th.arange(0, max_len, dtype=th.float).unsqueeze(1)
div_term = th.exp(th.arange(0, dim_model, 2, dtype=th.float) * div_term = th.exp(
-(np.log(10000.0) / dim_model)) th.arange(0, dim_model, 2, dtype=th.float)
* -(np.log(10000.0) / dim_model)
)
pe[:, 0::2] = th.sin(position * div_term) pe[:, 0::2] = th.sin(position * div_term)
pe[:, 1::2] = th.cos(position * div_term) pe[:, 1::2] = th.cos(position * div_term)
pe = pe.unsqueeze(0) pe = pe.unsqueeze(0)
self.register_buffer('pe', pe) # Not a parameter but should be in state_dict self.register_buffer(
"pe", pe
) # Not a parameter but should be in state_dict
def forward(self, pos): def forward(self, pos):
return th.index_select(self.pe, 1, pos).squeeze(0) return th.index_select(self.pe, 1, pos).squeeze(0)
...@@ -23,6 +29,7 @@ class PositionalEncoding(nn.Module): ...@@ -23,6 +29,7 @@ class PositionalEncoding(nn.Module):
class Embeddings(nn.Module): class Embeddings(nn.Module):
"Word Embedding module" "Word Embedding module"
def __init__(self, vocab_size, dim_model): def __init__(self, vocab_size, dim_model):
super(Embeddings, self).__init__() super(Embeddings, self).__init__()
self.lut = nn.Embedding(vocab_size, dim_model) self.lut = nn.Embedding(vocab_size, dim_model)
......
import torch as th import torch as th
def src_dot_dst(src_field, dst_field, out_field): def src_dot_dst(src_field, dst_field, out_field):
""" """
This function serves as a surrogate for `src_dot_dst` built-in apply_edge function. This function serves as a surrogate for `src_dot_dst` built-in apply_edge function.
""" """
def func(edges): def func(edges):
return {out_field: (edges.src[src_field] * edges.dst[dst_field]).sum(-1, keepdim=True)} return {
out_field: (edges.src[src_field] * edges.dst[dst_field]).sum(
-1, keepdim=True
)
}
return func return func
def scaled_exp(field, c): def scaled_exp(field, c):
""" """
This function applies $exp(x / c)$ for input $x$, which is required by *Scaled Dot-Product Attention* mentioned in the paper. This function applies $exp(x / c)$ for input $x$, which is required by *Scaled Dot-Product Attention* mentioned in the paper.
""" """
def func(edges): def func(edges):
return {field: th.exp((edges.data[field] / c).clamp(-10, 10))} return {field: th.exp((edges.data[field] / c).clamp(-10, 10))}
return func return func
...@@ -2,26 +2,27 @@ import torch as th ...@@ -2,26 +2,27 @@ import torch as th
import torch.nn as nn import torch.nn as nn
from torch.nn import LayerNorm from torch.nn import LayerNorm
class Generator(nn.Module): class Generator(nn.Module):
''' """
Generate next token from the representation. This part is separated from the decoder, mostly for the convenience of sharing weight between embedding and generator. Generate next token from the representation. This part is separated from the decoder, mostly for the convenience of sharing weight between embedding and generator.
log(softmax(Wx + b)) log(softmax(Wx + b))
''' """
def __init__(self, dim_model, vocab_size): def __init__(self, dim_model, vocab_size):
super(Generator, self).__init__() super(Generator, self).__init__()
self.proj = nn.Linear(dim_model, vocab_size) self.proj = nn.Linear(dim_model, vocab_size)
def forward(self, x): def forward(self, x):
return th.log_softmax( return th.log_softmax(self.proj(x), dim=-1)
self.proj(x), dim=-1
)
class SubLayerWrapper(nn.Module): class SubLayerWrapper(nn.Module):
''' """
The module wraps normalization, dropout, residual connection into one equation: The module wraps normalization, dropout, residual connection into one equation:
sublayerwrapper(sublayer)(x) = x + dropout(sublayer(norm(x))) sublayerwrapper(sublayer)(x) = x + dropout(sublayer(norm(x)))
''' """
def __init__(self, size, dropout): def __init__(self, size, dropout):
super(SubLayerWrapper, self).__init__() super(SubLayerWrapper, self).__init__()
self.norm = LayerNorm(size) self.norm = LayerNorm(size)
...@@ -32,10 +33,11 @@ class SubLayerWrapper(nn.Module): ...@@ -32,10 +33,11 @@ class SubLayerWrapper(nn.Module):
class PositionwiseFeedForward(nn.Module): class PositionwiseFeedForward(nn.Module):
''' """
This module implements feed-forward network(after the Multi-Head Network) equation: This module implements feed-forward network(after the Multi-Head Network) equation:
FFN(x) = max(0, x @ W_1 + b_1) @ W_2 + b_2 FFN(x) = max(0, x @ W_1 + b_1) @ W_2 + b_2
''' """
def __init__(self, dim_model, dim_ff, dropout=0.1): def __init__(self, dim_model, dim_ff, dropout=0.1):
super(PositionwiseFeedForward, self).__init__() super(PositionwiseFeedForward, self).__init__()
self.w_1 = nn.Linear(dim_model, dim_ff) self.w_1 = nn.Linear(dim_model, dim_ff)
...@@ -47,16 +49,17 @@ class PositionwiseFeedForward(nn.Module): ...@@ -47,16 +49,17 @@ class PositionwiseFeedForward(nn.Module):
import copy import copy
def clones(module, k): def clones(module, k):
return nn.ModuleList( return nn.ModuleList(copy.deepcopy(module) for _ in range(k))
copy.deepcopy(module) for _ in range(k)
)
class EncoderLayer(nn.Module): class EncoderLayer(nn.Module):
def __init__(self, size, self_attn, feed_forward, dropout): def __init__(self, size, self_attn, feed_forward, dropout):
super(EncoderLayer, self).__init__() super(EncoderLayer, self).__init__()
self.size = size self.size = size
self.self_attn = self_attn # (key, query, value, mask) self.self_attn = self_attn # (key, query, value, mask)
self.feed_forward = feed_forward self.feed_forward = feed_forward
self.sublayer = clones(SubLayerWrapper(size, dropout), 2) self.sublayer = clones(SubLayerWrapper(size, dropout), 2)
...@@ -65,7 +68,7 @@ class DecoderLayer(nn.Module): ...@@ -65,7 +68,7 @@ class DecoderLayer(nn.Module):
def __init__(self, size, self_attn, src_attn, feed_forward, dropout): def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
super(DecoderLayer, self).__init__() super(DecoderLayer, self).__init__()
self.size = size self.size = size
self.self_attn = self_attn # (key, query, value, mask) self.self_attn = self_attn # (key, query, value, mask)
self.src_attn = src_attn self.src_attn = src_attn
self.feed_forward = feed_forward self.feed_forward = feed_forward
self.sublayer = clones(SubLayerWrapper(size, dropout), 3) self.sublayer = clones(SubLayerWrapper(size, dropout), 3)
...@@ -15,19 +15,20 @@ class NoamOpt(object): ...@@ -15,19 +15,20 @@ class NoamOpt(object):
def rate(self, step=None): def rate(self, step=None):
if step is None: if step is None:
step = self._step step = self._step
return self.factor * \ return self.factor * (
(self.model_size ** (-0.5) * self.model_size ** (-0.5)
min(step ** (-0.5), step * self.warmup ** (-1.5)) * min(step ** (-0.5), step * self.warmup ** (-1.5))
) )
def step(self): def step(self):
self._step += 1 self._step += 1
rate = self.rate() rate = self.rate()
for p in self.optimizer.param_groups: for p in self.optimizer.param_groups:
p['lr'] = rate p["lr"] = rate
self._rate = rate self._rate = rate
self.optimizer.step() self.optimizer.step()
""" """
Default setting: Default setting:
......
# Beam Search Module # Beam Search Module
from modules import * import argparse
import numpy as n
from dataset import * from dataset import *
from modules import *
from tqdm import tqdm from tqdm import tqdm
import numpy as n
import argparse
k = 5 # Beam size k = 5 # Beam size
if __name__ == '__main__': if __name__ == "__main__":
argparser = argparse.ArgumentParser('testing translation model') argparser = argparse.ArgumentParser("testing translation model")
argparser.add_argument('--gpu', default=-1, help='gpu id') argparser.add_argument("--gpu", default=-1, help="gpu id")
argparser.add_argument('--N', default=6, type=int, help='num of layers') argparser.add_argument("--N", default=6, type=int, help="num of layers")
argparser.add_argument('--dataset', default='multi30k', help='dataset') argparser.add_argument("--dataset", default="multi30k", help="dataset")
argparser.add_argument('--batch', default=64, help='batch size') argparser.add_argument("--batch", default=64, help="batch size")
argparser.add_argument('--universal', action='store_true', help='use universal transformer') argparser.add_argument(
argparser.add_argument('--checkpoint', type=int, help='checkpoint: you must specify it') "--universal", action="store_true", help="use universal transformer"
argparser.add_argument('--print', action='store_true', help='whether to print translated text') )
argparser.add_argument(
"--checkpoint", type=int, help="checkpoint: you must specify it"
)
argparser.add_argument(
"--print", action="store_true", help="whether to print translated text"
)
args = argparser.parse_args() args = argparser.parse_args()
args_filter = ['batch', 'gpu', 'print'] args_filter = ["batch", "gpu", "print"]
exp_setting = '-'.join('{}'.format(v) for k, v in vars(args).items() if k not in args_filter) exp_setting = "-".join(
device = 'cpu' if args.gpu == -1 else 'cuda:{}'.format(args.gpu) "{}".format(v) for k, v in vars(args).items() if k not in args_filter
)
device = "cpu" if args.gpu == -1 else "cuda:{}".format(args.gpu)
dataset = get_dataset(args.dataset) dataset = get_dataset(args.dataset)
V = dataset.vocab_size V = dataset.vocab_size
dim_model = 512 dim_model = 512
fpred = open('pred.txt', 'w') fpred = open("pred.txt", "w")
fref = open('ref.txt', 'w') fref = open("ref.txt", "w")
graph_pool = GraphPool() graph_pool = GraphPool()
model = make_model(V, V, N=args.N, dim_model=dim_model) model = make_model(V, V, N=args.N, dim_model=dim_model)
with open('checkpoints/{}.pkl'.format(exp_setting), 'rb') as f: with open("checkpoints/{}.pkl".format(exp_setting), "rb") as f:
model.load_state_dict(th.load(f, map_location=lambda storage, loc: storage)) model.load_state_dict(
th.load(f, map_location=lambda storage, loc: storage)
)
model = model.to(device) model = model.to(device)
model.eval() model.eval()
test_iter = dataset(graph_pool, mode='test', batch_size=args.batch, device=device, k=k) test_iter = dataset(
graph_pool, mode="test", batch_size=args.batch, device=device, k=k
)
for i, g in enumerate(test_iter): for i, g in enumerate(test_iter):
with th.no_grad(): with th.no_grad():
output = model.infer(g, dataset.MAX_LENGTH, dataset.eos_id, k, alpha=0.6) output = model.infer(
g, dataset.MAX_LENGTH, dataset.eos_id, k, alpha=0.6
)
for line in dataset.get_sequence(output): for line in dataset.get_sequence(output):
if args.print: if args.print:
print(line) print(line)
print(line, file=fpred) print(line, file=fpred)
for line in dataset.tgt['test']: for line in dataset.tgt["test"]:
print(line.strip(), file=fref) print(line.strip(), file=fref)
fpred.close() fpred.close()
fref.close() fref.close()
os.system(r'bash scripts/bleu.sh pred.txt ref.txt') os.system(r"bash scripts/bleu.sh pred.txt ref.txt")
os.remove('pred.txt') os.remove("pred.txt")
os.remove('ref.txt') os.remove("ref.txt")
from modules import *
from loss import *
from optims import *
from dataset import *
from modules.config import *
import numpy as np
import argparse import argparse
import torch
from functools import partial from functools import partial
import numpy as np
import torch
import torch.distributed as dist import torch.distributed as dist
from dataset import *
from loss import *
from modules import *
from modules.config import *
from optims import *
def run_epoch(epoch, data_iter, dev_rank, ndev, model, loss_compute, is_train=True):
def run_epoch(
epoch, data_iter, dev_rank, ndev, model, loss_compute, is_train=True
):
universal = isinstance(model, UTransformer) universal = isinstance(model, UTransformer)
with loss_compute: with loss_compute:
for i, g in enumerate(data_iter): for i, g in enumerate(data_iter):
with T.set_grad_enabled(is_train): with T.set_grad_enabled(is_train):
if universal: if universal:
output, loss_act = model(g) output, loss_act = model(g)
if is_train: loss_act.backward(retain_graph=True) if is_train:
loss_act.backward(retain_graph=True)
else: else:
output = model(g) output = model(g)
tgt_y = g.tgt_y tgt_y = g.tgt_y
...@@ -25,29 +30,44 @@ def run_epoch(epoch, data_iter, dev_rank, ndev, model, loss_compute, is_train=Tr ...@@ -25,29 +30,44 @@ def run_epoch(epoch, data_iter, dev_rank, ndev, model, loss_compute, is_train=Tr
if universal: if universal:
for step in range(1, model.MAX_DEPTH + 1): for step in range(1, model.MAX_DEPTH + 1):
print("nodes entering step {}: {:.2f}%".format(step, (1.0 * model.stat[step] / model.stat[0]))) print(
"nodes entering step {}: {:.2f}%".format(
step, (1.0 * model.stat[step] / model.stat[0])
)
)
model.reset_stat() model.reset_stat()
print('Epoch {} {}: Dev {} average loss: {}, accuracy {}'.format( print(
epoch, "Training" if is_train else "Evaluating", "Epoch {} {}: Dev {} average loss: {}, accuracy {}".format(
dev_rank, loss_compute.avg_loss, loss_compute.accuracy)) epoch,
"Training" if is_train else "Evaluating",
dev_rank,
loss_compute.avg_loss,
loss_compute.accuracy,
)
)
def run(dev_id, args): def run(dev_id, args):
dist_init_method = 'tcp://{master_ip}:{master_port}'.format( dist_init_method = "tcp://{master_ip}:{master_port}".format(
master_ip=args.master_ip, master_port=args.master_port) master_ip=args.master_ip, master_port=args.master_port
)
world_size = args.ngpu world_size = args.ngpu
torch.distributed.init_process_group(backend="nccl", torch.distributed.init_process_group(
init_method=dist_init_method, backend="nccl",
world_size=world_size, init_method=dist_init_method,
rank=dev_id) world_size=world_size,
rank=dev_id,
)
gpu_rank = torch.distributed.get_rank() gpu_rank = torch.distributed.get_rank()
assert gpu_rank == dev_id assert gpu_rank == dev_id
main(dev_id, args) main(dev_id, args)
def main(dev_id, args): def main(dev_id, args):
if dev_id == -1: if dev_id == -1:
device = torch.device('cpu') device = torch.device("cpu")
else: else:
device = torch.device('cuda:{}'.format(dev_id)) device = torch.device("cuda:{}".format(dev_id))
# Set current device # Set current device
th.cuda.set_device(device) th.cuda.set_device(device)
# Prepare dataset # Prepare dataset
...@@ -58,8 +78,9 @@ def main(dev_id, args): ...@@ -58,8 +78,9 @@ def main(dev_id, args):
# Build graph pool # Build graph pool
graph_pool = GraphPool() graph_pool = GraphPool()
# Create model # Create model
model = make_model(V, V, N=args.N, dim_model=dim_model, model = make_model(
universal=args.universal) V, V, N=args.N, dim_model=dim_model, universal=args.universal
)
# Sharing weights between Encoder & Decoder # Sharing weights between Encoder & Decoder
model.src_embed.lut.weight = model.tgt_embed.lut.weight model.src_embed.lut.weight = model.tgt_embed.lut.weight
model.generator.proj.weight = model.tgt_embed.lut.weight model.generator.proj.weight = model.tgt_embed.lut.weight
...@@ -67,11 +88,12 @@ def main(dev_id, args): ...@@ -67,11 +88,12 @@ def main(dev_id, args):
model, criterion = model.to(device), criterion.to(device) model, criterion = model.to(device), criterion.to(device)
# Loss function # Loss function
if args.ngpu > 1: if args.ngpu > 1:
dev_rank = dev_id # current device id dev_rank = dev_id # current device id
ndev = args.ngpu # number of devices (including cpu) ndev = args.ngpu # number of devices (including cpu)
loss_compute = partial(MultiGPULossCompute, criterion, args.ngpu, loss_compute = partial(
args.grad_accum, model) MultiGPULossCompute, criterion, args.ngpu, args.grad_accum, model
else: # cpu or single gpu case )
else: # cpu or single gpu case
dev_rank = 0 dev_rank = 0
ndev = 1 ndev = 1
loss_compute = partial(SimpleLossCompute, criterion, args.grad_accum) loss_compute = partial(SimpleLossCompute, criterion, args.grad_accum)
...@@ -82,73 +104,134 @@ def main(dev_id, args): ...@@ -82,73 +104,134 @@ def main(dev_id, args):
param.data /= ndev param.data /= ndev
# Optimizer # Optimizer
model_opt = NoamOpt(dim_model, 0.1, 4000, model_opt = NoamOpt(
T.optim.Adam(model.parameters(), lr=1e-3, dim_model,
betas=(0.9, 0.98), eps=1e-9)) 0.1,
4000,
T.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9),
)
# Train & evaluate # Train & evaluate
for epoch in range(100): for epoch in range(100):
start = time.time() start = time.time()
train_iter = dataset(graph_pool, mode='train', batch_size=args.batch, train_iter = dataset(
device=device, dev_rank=dev_rank, ndev=ndev) graph_pool,
mode="train",
batch_size=args.batch,
device=device,
dev_rank=dev_rank,
ndev=ndev,
)
model.train(True) model.train(True)
run_epoch(epoch, train_iter, dev_rank, ndev, model, run_epoch(
loss_compute(opt=model_opt), is_train=True) epoch,
train_iter,
dev_rank,
ndev,
model,
loss_compute(opt=model_opt),
is_train=True,
)
if dev_rank == 0: if dev_rank == 0:
model.att_weight_map = None model.att_weight_map = None
model.eval() model.eval()
valid_iter = dataset(graph_pool, mode='valid', batch_size=args.batch, valid_iter = dataset(
device=device, dev_rank=dev_rank, ndev=1) graph_pool,
run_epoch(epoch, valid_iter, dev_rank, 1, model, mode="valid",
loss_compute(opt=None), is_train=False) batch_size=args.batch,
device=device,
dev_rank=dev_rank,
ndev=1,
)
run_epoch(
epoch,
valid_iter,
dev_rank,
1,
model,
loss_compute(opt=None),
is_train=False,
)
end = time.time() end = time.time()
print("epoch time: {}".format(end - start)) print("epoch time: {}".format(end - start))
# Visualize attention # Visualize attention
if args.viz: if args.viz:
src_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='src') src_seq = dataset.get_seq_by_id(
tgt_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='tgt')[:-1] VIZ_IDX, mode="valid", field="src"
draw_atts(model.att_weight_map, src_seq, tgt_seq, exp_setting, 'epoch_{}'.format(epoch)) )
args_filter = ['batch', 'gpus', 'viz', 'master_ip', 'master_port', 'grad_accum', 'ngpu'] tgt_seq = dataset.get_seq_by_id(
exp_setting = '-'.join('{}'.format(v) for k, v in vars(args).items() if k not in args_filter) VIZ_IDX, mode="valid", field="tgt"
with open('checkpoints/{}-{}.pkl'.format(exp_setting, epoch), 'wb') as f: )[:-1]
draw_atts(
model.att_weight_map,
src_seq,
tgt_seq,
exp_setting,
"epoch_{}".format(epoch),
)
args_filter = [
"batch",
"gpus",
"viz",
"master_ip",
"master_port",
"grad_accum",
"ngpu",
]
exp_setting = "-".join(
"{}".format(v)
for k, v in vars(args).items()
if k not in args_filter
)
with open(
"checkpoints/{}-{}.pkl".format(exp_setting, epoch), "wb"
) as f:
torch.save(model.state_dict(), f) torch.save(model.state_dict(), f)
if __name__ == '__main__':
if not os.path.exists('checkpoints'): if __name__ == "__main__":
os.makedirs('checkpoints') if not os.path.exists("checkpoints"):
os.makedirs("checkpoints")
np.random.seed(1111) np.random.seed(1111)
argparser = argparse.ArgumentParser('training translation model') argparser = argparse.ArgumentParser("training translation model")
argparser.add_argument('--gpus', default='-1', type=str, help='gpu id') argparser.add_argument("--gpus", default="-1", type=str, help="gpu id")
argparser.add_argument('--N', default=6, type=int, help='enc/dec layers') argparser.add_argument("--N", default=6, type=int, help="enc/dec layers")
argparser.add_argument('--dataset', default='multi30k', help='dataset') argparser.add_argument("--dataset", default="multi30k", help="dataset")
argparser.add_argument('--batch', default=128, type=int, help='batch size') argparser.add_argument("--batch", default=128, type=int, help="batch size")
argparser.add_argument('--viz', action='store_true', argparser.add_argument(
help='visualize attention') "--viz", action="store_true", help="visualize attention"
argparser.add_argument('--universal', action='store_true', )
help='use universal transformer') argparser.add_argument(
argparser.add_argument('--master-ip', type=str, default='127.0.0.1', "--universal", action="store_true", help="use universal transformer"
help='master ip address') )
argparser.add_argument('--master-port', type=str, default='12345', argparser.add_argument(
help='master port') "--master-ip", type=str, default="127.0.0.1", help="master ip address"
argparser.add_argument('--grad-accum', type=int, default=1, )
help='accumulate gradients for this many times ' argparser.add_argument(
'then update weights') "--master-port", type=str, default="12345", help="master port"
)
argparser.add_argument(
"--grad-accum",
type=int,
default=1,
help="accumulate gradients for this many times " "then update weights",
)
args = argparser.parse_args() args = argparser.parse_args()
print(args) print(args)
devices = list(map(int, args.gpus.split(','))) devices = list(map(int, args.gpus.split(",")))
if len(devices) == 1: if len(devices) == 1:
args.ngpu = 0 if devices[0] < 0 else 1 args.ngpu = 0 if devices[0] < 0 else 1
main(devices[0], args) main(devices[0], args)
else: else:
args.ngpu = len(devices) args.ngpu = len(devices)
mp = torch.multiprocessing.get_context('spawn') mp = torch.multiprocessing.get_context("spawn")
procs = [] procs = []
for dev_id in devices: for dev_id in devices:
procs.append(mp.Process(target=run, args=(dev_id, args), procs.append(
daemon=True)) mp.Process(target=run, args=(dev_id, args), daemon=True)
)
procs[-1].start() procs[-1].start()
for p in procs: for p in procs:
p.join() p.join()
import argparse import argparse
import collections import collections
import time import time
import numpy as np import numpy as np
import torch as th import torch as th
import torch.nn.functional as F import torch.nn.functional as F
import torch.nn.init as INIT import torch.nn.init as INIT
import torch.optim as optim import torch.optim as optim
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from tree_lstm import TreeLSTM
import dgl import dgl
from dgl.data.tree import SSTDataset from dgl.data.tree import SSTDataset
from tree_lstm import TreeLSTM SSTBatch = collections.namedtuple(
"SSTBatch", ["graph", "mask", "wordid", "label"]
)
SSTBatch = collections.namedtuple('SSTBatch', ['graph', 'mask', 'wordid', 'label'])
def batcher(device): def batcher(device):
def batcher_dev(batch): def batcher_dev(batch):
batch_trees = dgl.batch(batch) batch_trees = dgl.batch(batch)
return SSTBatch(graph=batch_trees, return SSTBatch(
mask=batch_trees.ndata['mask'].to(device), graph=batch_trees,
wordid=batch_trees.ndata['x'].to(device), mask=batch_trees.ndata["mask"].to(device),
label=batch_trees.ndata['y'].to(device)) wordid=batch_trees.ndata["x"].to(device),
label=batch_trees.ndata["y"].to(device),
)
return batcher_dev return batcher_dev
def main(args): def main(args):
np.random.seed(args.seed) np.random.seed(args.seed)
th.manual_seed(args.seed) th.manual_seed(args.seed)
...@@ -33,45 +40,67 @@ def main(args): ...@@ -33,45 +40,67 @@ def main(args):
best_dev_acc = 0 best_dev_acc = 0
cuda = args.gpu >= 0 cuda = args.gpu >= 0
device = th.device('cuda:{}'.format(args.gpu)) if cuda else th.device('cpu') device = th.device("cuda:{}".format(args.gpu)) if cuda else th.device("cpu")
if cuda: if cuda:
th.cuda.set_device(args.gpu) th.cuda.set_device(args.gpu)
trainset = SSTDataset() trainset = SSTDataset()
train_loader = DataLoader(dataset=trainset, train_loader = DataLoader(
batch_size=args.batch_size, dataset=trainset,
collate_fn=batcher(device), batch_size=args.batch_size,
shuffle=True, collate_fn=batcher(device),
num_workers=0) shuffle=True,
devset = SSTDataset(mode='dev') num_workers=0,
dev_loader = DataLoader(dataset=devset, )
batch_size=100, devset = SSTDataset(mode="dev")
collate_fn=batcher(device), dev_loader = DataLoader(
shuffle=False, dataset=devset,
num_workers=0) batch_size=100,
collate_fn=batcher(device),
testset = SSTDataset(mode='test') shuffle=False,
test_loader = DataLoader(dataset=testset, num_workers=0,
batch_size=100, collate_fn=batcher(device), shuffle=False, num_workers=0) )
model = TreeLSTM(trainset.vocab_size, testset = SSTDataset(mode="test")
args.x_size, test_loader = DataLoader(
args.h_size, dataset=testset,
trainset.num_classes, batch_size=100,
args.dropout, collate_fn=batcher(device),
cell_type='childsum' if args.child_sum else 'nary', shuffle=False,
pretrained_emb = trainset.pretrained_emb).to(device) num_workers=0,
)
model = TreeLSTM(
trainset.vocab_size,
args.x_size,
args.h_size,
trainset.num_classes,
args.dropout,
cell_type="childsum" if args.child_sum else "nary",
pretrained_emb=trainset.pretrained_emb,
).to(device)
print(model) print(model)
params_ex_emb =[x for x in list(model.parameters()) if x.requires_grad and x.size(0)!=trainset.vocab_size] params_ex_emb = [
x
for x in list(model.parameters())
if x.requires_grad and x.size(0) != trainset.vocab_size
]
params_emb = list(model.embedding.parameters()) params_emb = list(model.embedding.parameters())
for p in params_ex_emb: for p in params_ex_emb:
if p.dim() > 1: if p.dim() > 1:
INIT.xavier_uniform_(p) INIT.xavier_uniform_(p)
optimizer = optim.Adagrad([ optimizer = optim.Adagrad(
{'params':params_ex_emb, 'lr':args.lr, 'weight_decay':args.weight_decay}, [
{'params':params_emb, 'lr':0.1*args.lr}]) {
"params": params_ex_emb,
"lr": args.lr,
"weight_decay": args.weight_decay,
},
{"params": params_emb, "lr": 0.1 * args.lr},
]
)
dur = [] dur = []
for epoch in range(args.epochs): for epoch in range(args.epochs):
...@@ -83,28 +112,47 @@ def main(args): ...@@ -83,28 +112,47 @@ def main(args):
h = th.zeros((n, args.h_size)).to(device) h = th.zeros((n, args.h_size)).to(device)
c = th.zeros((n, args.h_size)).to(device) c = th.zeros((n, args.h_size)).to(device)
if step >= 3: if step >= 3:
t0 = time.time() # tik t0 = time.time() # tik
logits = model(batch, g, h, c) logits = model(batch, g, h, c)
logp = F.log_softmax(logits, 1) logp = F.log_softmax(logits, 1)
loss = F.nll_loss(logp, batch.label, reduction='sum') loss = F.nll_loss(logp, batch.label, reduction="sum")
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
if step >= 3: if step >= 3:
dur.append(time.time() - t0) # tok dur.append(time.time() - t0) # tok
if step > 0 and step % args.log_every == 0: if step > 0 and step % args.log_every == 0:
pred = th.argmax(logits, 1) pred = th.argmax(logits, 1)
acc = th.sum(th.eq(batch.label, pred)) acc = th.sum(th.eq(batch.label, pred))
root_ids = [i for i in range(g.number_of_nodes()) if g.out_degree(i)==0] root_ids = [
root_acc = np.sum(batch.label.cpu().data.numpy()[root_ids] == pred.cpu().data.numpy()[root_ids]) i
for i in range(g.number_of_nodes())
if g.out_degree(i) == 0
]
root_acc = np.sum(
batch.label.cpu().data.numpy()[root_ids]
== pred.cpu().data.numpy()[root_ids]
)
print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} | Root Acc {:.4f} | Time(s) {:.4f}".format( print(
epoch, step, loss.item(), 1.0*acc.item()/len(batch.label), 1.0*root_acc/len(root_ids), np.mean(dur))) "Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} | Root Acc {:.4f} | Time(s) {:.4f}".format(
print('Epoch {:05d} training time {:.4f}s'.format(epoch, time.time() - t_epoch)) epoch,
step,
loss.item(),
1.0 * acc.item() / len(batch.label),
1.0 * root_acc / len(root_ids),
np.mean(dur),
)
)
print(
"Epoch {:05d} training time {:.4f}s".format(
epoch, time.time() - t_epoch
)
)
# eval on dev set # eval on dev set
accs = [] accs = []
...@@ -121,30 +169,44 @@ def main(args): ...@@ -121,30 +169,44 @@ def main(args):
pred = th.argmax(logits, 1) pred = th.argmax(logits, 1)
acc = th.sum(th.eq(batch.label, pred)).item() acc = th.sum(th.eq(batch.label, pred)).item()
accs.append([acc, len(batch.label)]) accs.append([acc, len(batch.label)])
root_ids = [i for i in range(g.number_of_nodes()) if g.out_degree(i)==0] root_ids = [
root_acc = np.sum(batch.label.cpu().data.numpy()[root_ids] == pred.cpu().data.numpy()[root_ids]) i for i in range(g.number_of_nodes()) if g.out_degree(i) == 0
]
root_acc = np.sum(
batch.label.cpu().data.numpy()[root_ids]
== pred.cpu().data.numpy()[root_ids]
)
root_accs.append([root_acc, len(root_ids)]) root_accs.append([root_acc, len(root_ids)])
dev_acc = 1.0*np.sum([x[0] for x in accs])/np.sum([x[1] for x in accs]) dev_acc = (
dev_root_acc = 1.0*np.sum([x[0] for x in root_accs])/np.sum([x[1] for x in root_accs]) 1.0 * np.sum([x[0] for x in accs]) / np.sum([x[1] for x in accs])
print("Epoch {:05d} | Dev Acc {:.4f} | Root Acc {:.4f}".format( )
epoch, dev_acc, dev_root_acc)) dev_root_acc = (
1.0
* np.sum([x[0] for x in root_accs])
/ np.sum([x[1] for x in root_accs])
)
print(
"Epoch {:05d} | Dev Acc {:.4f} | Root Acc {:.4f}".format(
epoch, dev_acc, dev_root_acc
)
)
if dev_root_acc > best_dev_acc: if dev_root_acc > best_dev_acc:
best_dev_acc = dev_root_acc best_dev_acc = dev_root_acc
best_epoch = epoch best_epoch = epoch
th.save(model.state_dict(), 'best_{}.pkl'.format(args.seed)) th.save(model.state_dict(), "best_{}.pkl".format(args.seed))
else: else:
if best_epoch <= epoch - 10: if best_epoch <= epoch - 10:
break break
# lr decay # lr decay
for param_group in optimizer.param_groups: for param_group in optimizer.param_groups:
param_group['lr'] = max(1e-5, param_group['lr']*0.99) #10 param_group["lr"] = max(1e-5, param_group["lr"] * 0.99) # 10
print(param_group['lr']) print(param_group["lr"])
# test # test
model.load_state_dict(th.load('best_{}.pkl'.format(args.seed))) model.load_state_dict(th.load("best_{}.pkl".format(args.seed)))
accs = [] accs = []
root_accs = [] root_accs = []
model.eval() model.eval()
...@@ -159,29 +221,44 @@ def main(args): ...@@ -159,29 +221,44 @@ def main(args):
pred = th.argmax(logits, 1) pred = th.argmax(logits, 1)
acc = th.sum(th.eq(batch.label, pred)).item() acc = th.sum(th.eq(batch.label, pred)).item()
accs.append([acc, len(batch.label)]) accs.append([acc, len(batch.label)])
root_ids = [i for i in range(g.number_of_nodes()) if g.out_degree(i)==0] root_ids = [
root_acc = np.sum(batch.label.cpu().data.numpy()[root_ids] == pred.cpu().data.numpy()[root_ids]) i for i in range(g.number_of_nodes()) if g.out_degree(i) == 0
]
root_acc = np.sum(
batch.label.cpu().data.numpy()[root_ids]
== pred.cpu().data.numpy()[root_ids]
)
root_accs.append([root_acc, len(root_ids)]) root_accs.append([root_acc, len(root_ids)])
test_acc = 1.0*np.sum([x[0] for x in accs])/np.sum([x[1] for x in accs]) test_acc = 1.0 * np.sum([x[0] for x in accs]) / np.sum([x[1] for x in accs])
test_root_acc = 1.0*np.sum([x[0] for x in root_accs])/np.sum([x[1] for x in root_accs]) test_root_acc = (
print('------------------------------------------------------------------------------------') 1.0
print("Epoch {:05d} | Test Acc {:.4f} | Root Acc {:.4f}".format( * np.sum([x[0] for x in root_accs])
best_epoch, test_acc, test_root_acc)) / np.sum([x[1] for x in root_accs])
)
print(
"------------------------------------------------------------------------------------"
)
print(
"Epoch {:05d} | Test Acc {:.4f} | Root Acc {:.4f}".format(
best_epoch, test_acc, test_root_acc
)
)
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=int, default=-1) parser.add_argument("--gpu", type=int, default=-1)
parser.add_argument('--seed', type=int, default=41) parser.add_argument("--seed", type=int, default=41)
parser.add_argument('--batch-size', type=int, default=20) parser.add_argument("--batch-size", type=int, default=20)
parser.add_argument('--child-sum', action='store_true') parser.add_argument("--child-sum", action="store_true")
parser.add_argument('--x-size', type=int, default=300) parser.add_argument("--x-size", type=int, default=300)
parser.add_argument('--h-size', type=int, default=150) parser.add_argument("--h-size", type=int, default=150)
parser.add_argument('--epochs', type=int, default=100) parser.add_argument("--epochs", type=int, default=100)
parser.add_argument('--log-every', type=int, default=5) parser.add_argument("--log-every", type=int, default=5)
parser.add_argument('--lr', type=float, default=0.05) parser.add_argument("--lr", type=float, default=0.05)
parser.add_argument('--weight-decay', type=float, default=1e-4) parser.add_argument("--weight-decay", type=float, default=1e-4)
parser.add_argument('--dropout', type=float, default=0.5) parser.add_argument("--dropout", type=float, default=0.5)
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
main(args) main(args)
''' """
****************NOTE***************** ****************NOTE*****************
CREDITS : Thomas Kipf CREDITS : Thomas Kipf
since datasets are the same as those in kipf's implementation, since datasets are the same as those in kipf's implementation,
Their preprocessing source was used as-is. Their preprocessing source was used as-is.
************************************* *************************************
''' """
import numpy as np
import sys
import pickle as pkl import pickle as pkl
import sys
import networkx as nx import networkx as nx
import numpy as np
import scipy.sparse as sp import scipy.sparse as sp
...@@ -21,22 +22,26 @@ def parse_index_file(filename): ...@@ -21,22 +22,26 @@ def parse_index_file(filename):
def load_data(dataset): def load_data(dataset):
# load the data: x, tx, allx, graph # load the data: x, tx, allx, graph
names = ['x', 'tx', 'allx', 'graph'] names = ["x", "tx", "allx", "graph"]
objects = [] objects = []
for i in range(len(names)): for i in range(len(names)):
with open("data/ind.{}.{}".format(dataset, names[i]), 'rb') as f: with open("data/ind.{}.{}".format(dataset, names[i]), "rb") as f:
if sys.version_info > (3, 0): if sys.version_info > (3, 0):
objects.append(pkl.load(f, encoding='latin1')) objects.append(pkl.load(f, encoding="latin1"))
else: else:
objects.append(pkl.load(f)) objects.append(pkl.load(f))
x, tx, allx, graph = tuple(objects) x, tx, allx, graph = tuple(objects)
test_idx_reorder = parse_index_file("data/ind.{}.test.index".format(dataset)) test_idx_reorder = parse_index_file(
"data/ind.{}.test.index".format(dataset)
)
test_idx_range = np.sort(test_idx_reorder) test_idx_range = np.sort(test_idx_reorder)
if dataset == 'citeseer': if dataset == "citeseer":
# Fix citeseer dataset (there are some isolated nodes in the graph) # Fix citeseer dataset (there are some isolated nodes in the graph)
# Find isolated nodes, add them as zero-vecs into the right position # Find isolated nodes, add them as zero-vecs into the right position
test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder) + 1) test_idx_range_full = range(
min(test_idx_reorder), max(test_idx_reorder) + 1
)
tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1])) tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
tx_extended[test_idx_range - min(test_idx_range), :] = tx tx_extended[test_idx_range - min(test_idx_range), :] = tx
tx = tx_extended tx = tx_extended
......
from dgl.nn.pytorch import GraphConv
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from train import device from train import device
from dgl.nn.pytorch import GraphConv
class VGAEModel(nn.Module): class VGAEModel(nn.Module):
def __init__(self, in_dim, hidden1_dim, hidden2_dim): def __init__(self, in_dim, hidden1_dim, hidden2_dim):
...@@ -13,17 +13,38 @@ class VGAEModel(nn.Module): ...@@ -13,17 +13,38 @@ class VGAEModel(nn.Module):
self.hidden1_dim = hidden1_dim self.hidden1_dim = hidden1_dim
self.hidden2_dim = hidden2_dim self.hidden2_dim = hidden2_dim
layers = [GraphConv(self.in_dim, self.hidden1_dim, activation=F.relu, allow_zero_in_degree=True), layers = [
GraphConv(self.hidden1_dim, self.hidden2_dim, activation=lambda x: x, allow_zero_in_degree=True), GraphConv(
GraphConv(self.hidden1_dim, self.hidden2_dim, activation=lambda x: x, allow_zero_in_degree=True)] self.in_dim,
self.hidden1_dim,
activation=F.relu,
allow_zero_in_degree=True,
),
GraphConv(
self.hidden1_dim,
self.hidden2_dim,
activation=lambda x: x,
allow_zero_in_degree=True,
),
GraphConv(
self.hidden1_dim,
self.hidden2_dim,
activation=lambda x: x,
allow_zero_in_degree=True,
),
]
self.layers = nn.ModuleList(layers) self.layers = nn.ModuleList(layers)
def encoder(self, g, features): def encoder(self, g, features):
h = self.layers[0](g, features) h = self.layers[0](g, features)
self.mean = self.layers[1](g, h) self.mean = self.layers[1](g, h)
self.log_std = self.layers[2](g, h) self.log_std = self.layers[2](g, h)
gaussian_noise = torch.randn(features.size(0), self.hidden2_dim).to(device) gaussian_noise = torch.randn(features.size(0), self.hidden2_dim).to(
sampled_z = self.mean + gaussian_noise * torch.exp(self.log_std).to(device) device
)
sampled_z = self.mean + gaussian_noise * torch.exp(self.log_std).to(
device
)
return sampled_z return sampled_z
def decoder(self, z): def decoder(self, z):
......
...@@ -9,7 +9,9 @@ def mask_test_edges(adj): ...@@ -9,7 +9,9 @@ def mask_test_edges(adj):
# TODO: Clean up. # TODO: Clean up.
# Remove diagonal elements # Remove diagonal elements
adj = adj - sp.dia_matrix((adj.diagonal()[np.newaxis, :], [0]), shape=adj.shape) adj = adj - sp.dia_matrix(
(adj.diagonal()[np.newaxis, :], [0]), shape=adj.shape
)
adj.eliminate_zeros() adj.eliminate_zeros()
# Check that diag is zero: # Check that diag is zero:
assert np.diag(adj.todense()).sum() == 0 assert np.diag(adj.todense()).sum() == 0
...@@ -18,16 +20,18 @@ def mask_test_edges(adj): ...@@ -18,16 +20,18 @@ def mask_test_edges(adj):
adj_tuple = sparse_to_tuple(adj_triu) adj_tuple = sparse_to_tuple(adj_triu)
edges = adj_tuple[0] edges = adj_tuple[0]
edges_all = sparse_to_tuple(adj)[0] edges_all = sparse_to_tuple(adj)[0]
num_test = int(np.floor(edges.shape[0] / 10.)) num_test = int(np.floor(edges.shape[0] / 10.0))
num_val = int(np.floor(edges.shape[0] / 20.)) num_val = int(np.floor(edges.shape[0] / 20.0))
all_edge_idx = list(range(edges.shape[0])) all_edge_idx = list(range(edges.shape[0]))
np.random.shuffle(all_edge_idx) np.random.shuffle(all_edge_idx)
val_edge_idx = all_edge_idx[:num_val] val_edge_idx = all_edge_idx[:num_val]
test_edge_idx = all_edge_idx[num_val:(num_val + num_test)] test_edge_idx = all_edge_idx[num_val : (num_val + num_test)]
test_edges = edges[test_edge_idx] test_edges = edges[test_edge_idx]
val_edges = edges[val_edge_idx] val_edges = edges[val_edge_idx]
train_edges = np.delete(edges, np.hstack([test_edge_idx, val_edge_idx]), axis=0) train_edges = np.delete(
edges, np.hstack([test_edge_idx, val_edge_idx]), axis=0
)
def ismember(a, b, tol=5): def ismember(a, b, tol=5):
rows_close = np.all(np.round(a - b[:, None], tol) == 0, axis=-1) rows_close = np.all(np.round(a - b[:, None], tol) == 0, axis=-1)
...@@ -78,28 +82,39 @@ def mask_test_edges(adj): ...@@ -78,28 +82,39 @@ def mask_test_edges(adj):
data = np.ones(train_edges.shape[0]) data = np.ones(train_edges.shape[0])
# Re-build adj matrix # Re-build adj matrix
adj_train = sp.csr_matrix((data, (train_edges[:, 0], train_edges[:, 1])), shape=adj.shape) adj_train = sp.csr_matrix(
(data, (train_edges[:, 0], train_edges[:, 1])), shape=adj.shape
)
adj_train = adj_train + adj_train.T adj_train = adj_train + adj_train.T
# NOTE: these edge lists only contain single direction of edge! # NOTE: these edge lists only contain single direction of edge!
return adj_train, train_edges, val_edges, val_edges_false, test_edges, test_edges_false return (
adj_train,
train_edges,
val_edges,
val_edges_false,
test_edges,
test_edges_false,
)
def mask_test_edges_dgl(graph, adj): def mask_test_edges_dgl(graph, adj):
src, dst = graph.edges() src, dst = graph.edges()
edges_all = torch.stack([src, dst], dim=0) edges_all = torch.stack([src, dst], dim=0)
edges_all = edges_all.t().cpu().numpy() edges_all = edges_all.t().cpu().numpy()
num_test = int(np.floor(edges_all.shape[0] / 10.)) num_test = int(np.floor(edges_all.shape[0] / 10.0))
num_val = int(np.floor(edges_all.shape[0] / 20.)) num_val = int(np.floor(edges_all.shape[0] / 20.0))
all_edge_idx = list(range(edges_all.shape[0])) all_edge_idx = list(range(edges_all.shape[0]))
np.random.shuffle(all_edge_idx) np.random.shuffle(all_edge_idx)
val_edge_idx = all_edge_idx[:num_val] val_edge_idx = all_edge_idx[:num_val]
test_edge_idx = all_edge_idx[num_val:(num_val + num_test)] test_edge_idx = all_edge_idx[num_val : (num_val + num_test)]
train_edge_idx = all_edge_idx[(num_val + num_test):] train_edge_idx = all_edge_idx[(num_val + num_test) :]
test_edges = edges_all[test_edge_idx] test_edges = edges_all[test_edge_idx]
val_edges = edges_all[val_edge_idx] val_edges = edges_all[val_edge_idx]
train_edges = np.delete(edges_all, np.hstack([test_edge_idx, val_edge_idx]), axis=0) train_edges = np.delete(
edges_all, np.hstack([test_edge_idx, val_edge_idx]), axis=0
)
def ismember(a, b, tol=5): def ismember(a, b, tol=5):
rows_close = np.all(np.round(a - b[:, None], tol) == 0, axis=-1) rows_close = np.all(np.round(a - b[:, None], tol) == 0, axis=-1)
...@@ -148,7 +163,13 @@ def mask_test_edges_dgl(graph, adj): ...@@ -148,7 +163,13 @@ def mask_test_edges_dgl(graph, adj):
assert ~ismember(val_edges, test_edges) assert ~ismember(val_edges, test_edges)
# NOTE: these edge lists only contain single direction of edge! # NOTE: these edge lists only contain single direction of edge!
return train_edge_idx, val_edges, val_edges_false, test_edges, test_edges_false return (
train_edge_idx,
val_edges,
val_edges_false,
test_edges,
test_edges_false,
)
def sparse_to_tuple(sparse_mx): def sparse_to_tuple(sparse_mx):
...@@ -165,5 +186,10 @@ def preprocess_graph(adj): ...@@ -165,5 +186,10 @@ def preprocess_graph(adj):
adj_ = adj + sp.eye(adj.shape[0]) adj_ = adj + sp.eye(adj.shape[0])
rowsum = np.array(adj_.sum(1)) rowsum = np.array(adj_.sum(1))
degree_mat_inv_sqrt = sp.diags(np.power(rowsum, -0.5).flatten()) degree_mat_inv_sqrt = sp.diags(np.power(rowsum, -0.5).flatten())
adj_normalized = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt).tocoo() adj_normalized = (
adj_.dot(degree_mat_inv_sqrt)
.transpose()
.dot(degree_mat_inv_sqrt)
.tocoo()
)
return adj_normalized, sparse_to_tuple(adj_normalized) return adj_normalized, sparse_to_tuple(adj_normalized)
...@@ -2,42 +2,77 @@ import argparse ...@@ -2,42 +2,77 @@ import argparse
import os import os
import time import time
import dgl import model
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
import numpy as np import numpy as np
import scipy.sparse as sp import scipy.sparse as sp
from sklearn.metrics import roc_auc_score, average_precision_score
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from input_data import load_data from input_data import load_data
import model from preprocess import (
from preprocess import mask_test_edges, mask_test_edges_dgl, sparse_to_tuple, preprocess_graph mask_test_edges,
mask_test_edges_dgl,
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' preprocess_graph,
sparse_to_tuple,
parser = argparse.ArgumentParser(description='Variant Graph Auto Encoder') )
parser.add_argument('--learning_rate', type=float, default=0.01, help='Initial learning rate.') from sklearn.metrics import average_precision_score, roc_auc_score
parser.add_argument('--epochs', '-e', type=int, default=200, help='Number of epochs to train.')
parser.add_argument('--hidden1', '-h1', type=int, default=32, help='Number of units in hidden layer 1.') import dgl
parser.add_argument('--hidden2', '-h2', type=int, default=16, help='Number of units in hidden layer 2.') from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
parser.add_argument('--datasrc', '-s', type=str, default='dgl',
help='Dataset download from dgl Dataset or website.') os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
parser.add_argument('--dataset', '-d', type=str, default='cora', help='Dataset string.')
parser.add_argument('--gpu_id', type=int, default=0, help='GPU id to use.') parser = argparse.ArgumentParser(description="Variant Graph Auto Encoder")
parser.add_argument(
"--learning_rate", type=float, default=0.01, help="Initial learning rate."
)
parser.add_argument(
"--epochs", "-e", type=int, default=200, help="Number of epochs to train."
)
parser.add_argument(
"--hidden1",
"-h1",
type=int,
default=32,
help="Number of units in hidden layer 1.",
)
parser.add_argument(
"--hidden2",
"-h2",
type=int,
default=16,
help="Number of units in hidden layer 2.",
)
parser.add_argument(
"--datasrc",
"-s",
type=str,
default="dgl",
help="Dataset download from dgl Dataset or website.",
)
parser.add_argument(
"--dataset", "-d", type=str, default="cora", help="Dataset string."
)
parser.add_argument("--gpu_id", type=int, default=0, help="GPU id to use.")
args = parser.parse_args() args = parser.parse_args()
# check device # check device
device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else "cpu") device = torch.device(
"cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else "cpu"
)
# device = "cpu" # device = "cpu"
# roc_means = [] # roc_means = []
# ap_means = [] # ap_means = []
def compute_loss_para(adj): def compute_loss_para(adj):
pos_weight = ((adj.shape[0] * adj.shape[0] - adj.sum()) / adj.sum()) pos_weight = (adj.shape[0] * adj.shape[0] - adj.sum()) / adj.sum()
norm = adj.shape[0] * adj.shape[0] / float((adj.shape[0] * adj.shape[0] - adj.sum()) * 2) norm = (
adj.shape[0]
* adj.shape[0]
/ float((adj.shape[0] * adj.shape[0] - adj.sum()) * 2)
)
weight_mask = adj.view(-1) == 1 weight_mask = adj.view(-1) == 1
weight_tensor = torch.ones(weight_mask.size(0)).to(device) weight_tensor = torch.ones(weight_mask.size(0)).to(device)
weight_tensor[weight_mask] = pos_weight weight_tensor[weight_mask] = pos_weight
...@@ -75,25 +110,31 @@ def get_scores(edges_pos, edges_neg, adj_rec): ...@@ -75,25 +110,31 @@ def get_scores(edges_pos, edges_neg, adj_rec):
def dgl_main(): def dgl_main():
# Load from DGL dataset # Load from DGL dataset
if args.dataset == 'cora': if args.dataset == "cora":
dataset = CoraGraphDataset(reverse_edge=False) dataset = CoraGraphDataset(reverse_edge=False)
elif args.dataset == 'citeseer': elif args.dataset == "citeseer":
dataset = CiteseerGraphDataset(reverse_edge=False) dataset = CiteseerGraphDataset(reverse_edge=False)
elif args.dataset == 'pubmed': elif args.dataset == "pubmed":
dataset = PubmedGraphDataset(reverse_edge=False) dataset = PubmedGraphDataset(reverse_edge=False)
else: else:
raise NotImplementedError raise NotImplementedError
graph = dataset[0] graph = dataset[0]
# Extract node features # Extract node features
feats = graph.ndata.pop('feat').to(device) feats = graph.ndata.pop("feat").to(device)
in_dim = feats.shape[-1] in_dim = feats.shape[-1]
# generate input # generate input
adj_orig = graph.adjacency_matrix().to_dense() adj_orig = graph.adjacency_matrix().to_dense()
# build test set with 10% positive links # build test set with 10% positive links
train_edge_idx, val_edges, val_edges_false, test_edges, test_edges_false = mask_test_edges_dgl(graph, adj_orig) (
train_edge_idx,
val_edges,
val_edges_false,
test_edges,
test_edges_false,
) = mask_test_edges_dgl(graph, adj_orig)
graph = graph.to(device) graph = graph.to(device)
...@@ -112,7 +153,10 @@ def dgl_main(): ...@@ -112,7 +153,10 @@ def dgl_main():
# create training component # create training component
optimizer = torch.optim.Adam(vgae_model.parameters(), lr=args.learning_rate) optimizer = torch.optim.Adam(vgae_model.parameters(), lr=args.learning_rate)
print('Total Parameters:', sum([p.nelement() for p in vgae_model.parameters()])) print(
"Total Parameters:",
sum([p.nelement() for p in vgae_model.parameters()]),
)
# create training epoch # create training epoch
for epoch in range(args.epochs): for epoch in range(args.epochs):
...@@ -124,10 +168,21 @@ def dgl_main(): ...@@ -124,10 +168,21 @@ def dgl_main():
logits = vgae_model.forward(graph, feats) logits = vgae_model.forward(graph, feats)
# compute loss # compute loss
loss = norm * F.binary_cross_entropy(logits.view(-1), adj.view(-1), weight=weight_tensor) loss = norm * F.binary_cross_entropy(
kl_divergence = 0.5 / logits.size(0) * ( logits.view(-1), adj.view(-1), weight=weight_tensor
1 + 2 * vgae_model.log_std - vgae_model.mean ** 2 - torch.exp(vgae_model.log_std) ** 2).sum( )
1).mean() kl_divergence = (
0.5
/ logits.size(0)
* (
1
+ 2 * vgae_model.log_std
- vgae_model.mean**2
- torch.exp(vgae_model.log_std) ** 2
)
.sum(1)
.mean()
)
loss -= kl_divergence loss -= kl_divergence
# backward # backward
...@@ -140,14 +195,31 @@ def dgl_main(): ...@@ -140,14 +195,31 @@ def dgl_main():
val_roc, val_ap = get_scores(val_edges, val_edges_false, logits) val_roc, val_ap = get_scores(val_edges, val_edges_false, logits)
# Print out performance # Print out performance
print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(loss.item()), "train_acc=", print(
"{:.5f}".format(train_acc), "val_roc=", "{:.5f}".format(val_roc), "val_ap=", "{:.5f}".format(val_ap), "Epoch:",
"time=", "{:.5f}".format(time.time() - t)) "%04d" % (epoch + 1),
"train_loss=",
"{:.5f}".format(loss.item()),
"train_acc=",
"{:.5f}".format(train_acc),
"val_roc=",
"{:.5f}".format(val_roc),
"val_ap=",
"{:.5f}".format(val_ap),
"time=",
"{:.5f}".format(time.time() - t),
)
test_roc, test_ap = get_scores(test_edges, test_edges_false, logits) test_roc, test_ap = get_scores(test_edges, test_edges_false, logits)
# roc_means.append(test_roc) # roc_means.append(test_roc)
# ap_means.append(test_ap) # ap_means.append(test_ap)
print("End of training!", "test_roc=", "{:.5f}".format(test_roc), "test_ap=", "{:.5f}".format(test_ap)) print(
"End of training!",
"test_roc=",
"{:.5f}".format(test_roc),
"test_ap=",
"{:.5f}".format(test_ap),
)
def web_main(): def web_main():
...@@ -157,10 +229,19 @@ def web_main(): ...@@ -157,10 +229,19 @@ def web_main():
# Store original adjacency matrix (without diagonal entries) for later # Store original adjacency matrix (without diagonal entries) for later
adj_orig = adj adj_orig = adj
adj_orig = adj_orig - sp.dia_matrix((adj_orig.diagonal()[np.newaxis, :], [0]), shape=adj_orig.shape) adj_orig = adj_orig - sp.dia_matrix(
(adj_orig.diagonal()[np.newaxis, :], [0]), shape=adj_orig.shape
)
adj_orig.eliminate_zeros() adj_orig.eliminate_zeros()
adj_train, train_edges, val_edges, val_edges_false, test_edges, test_edges_false = mask_test_edges(adj) (
adj_train,
train_edges,
val_edges,
val_edges_false,
test_edges,
test_edges_false,
) = mask_test_edges(adj)
adj = adj_train adj = adj_train
# # Create model # # Create model
...@@ -176,20 +257,30 @@ def web_main(): ...@@ -176,20 +257,30 @@ def web_main():
# Create Model # Create Model
pos_weight = float(adj.shape[0] * adj.shape[0] - adj.sum()) / adj.sum() pos_weight = float(adj.shape[0] * adj.shape[0] - adj.sum()) / adj.sum()
norm = adj.shape[0] * adj.shape[0] / float((adj.shape[0] * adj.shape[0] - adj.sum()) * 2) norm = (
adj.shape[0]
* adj.shape[0]
/ float((adj.shape[0] * adj.shape[0] - adj.sum()) * 2)
)
adj_label = adj_train + sp.eye(adj_train.shape[0]) adj_label = adj_train + sp.eye(adj_train.shape[0])
adj_label = sparse_to_tuple(adj_label) adj_label = sparse_to_tuple(adj_label)
adj_norm = torch.sparse.FloatTensor(torch.LongTensor(adj_norm[0].T), adj_norm = torch.sparse.FloatTensor(
torch.FloatTensor(adj_norm[1]), torch.LongTensor(adj_norm[0].T),
torch.Size(adj_norm[2])) torch.FloatTensor(adj_norm[1]),
adj_label = torch.sparse.FloatTensor(torch.LongTensor(adj_label[0].T), torch.Size(adj_norm[2]),
torch.FloatTensor(adj_label[1]), )
torch.Size(adj_label[2])) adj_label = torch.sparse.FloatTensor(
features = torch.sparse.FloatTensor(torch.LongTensor(features[0].T), torch.LongTensor(adj_label[0].T),
torch.FloatTensor(features[1]), torch.FloatTensor(adj_label[1]),
torch.Size(features[2])) torch.Size(adj_label[2]),
)
features = torch.sparse.FloatTensor(
torch.LongTensor(features[0].T),
torch.FloatTensor(features[1]),
torch.Size(features[2]),
)
weight_mask = adj_label.to_dense().view(-1) == 1 weight_mask = adj_label.to_dense().view(-1) == 1
weight_tensor = torch.ones(weight_mask.size(0)) weight_tensor = torch.ones(weight_mask.size(0))
...@@ -201,7 +292,10 @@ def web_main(): ...@@ -201,7 +292,10 @@ def web_main():
vgae_model = model.VGAEModel(in_dim, args.hidden1, args.hidden2) vgae_model = model.VGAEModel(in_dim, args.hidden1, args.hidden2)
# create training component # create training component
optimizer = torch.optim.Adam(vgae_model.parameters(), lr=args.learning_rate) optimizer = torch.optim.Adam(vgae_model.parameters(), lr=args.learning_rate)
print('Total Parameters:', sum([p.nelement() for p in vgae_model.parameters()])) print(
"Total Parameters:",
sum([p.nelement() for p in vgae_model.parameters()]),
)
def get_scores(edges_pos, edges_neg, adj_rec): def get_scores(edges_pos, edges_neg, adj_rec):
def sigmoid(x): def sigmoid(x):
...@@ -245,10 +339,21 @@ def web_main(): ...@@ -245,10 +339,21 @@ def web_main():
logits = vgae_model.forward(graph, features) logits = vgae_model.forward(graph, features)
# compute loss # compute loss
loss = norm * F.binary_cross_entropy(logits.view(-1), adj_label.to_dense().view(-1), weight=weight_tensor) loss = norm * F.binary_cross_entropy(
kl_divergence = 0.5 / logits.size(0) * ( logits.view(-1), adj_label.to_dense().view(-1), weight=weight_tensor
1 + 2 * vgae_model.log_std - vgae_model.mean ** 2 - torch.exp(vgae_model.log_std) ** 2).sum( )
1).mean() kl_divergence = (
0.5
/ logits.size(0)
* (
1
+ 2 * vgae_model.log_std
- vgae_model.mean**2
- torch.exp(vgae_model.log_std) ** 2
)
.sum(1)
.mean()
)
loss -= kl_divergence loss -= kl_divergence
# backward # backward
...@@ -261,12 +366,29 @@ def web_main(): ...@@ -261,12 +366,29 @@ def web_main():
val_roc, val_ap = get_scores(val_edges, val_edges_false, logits) val_roc, val_ap = get_scores(val_edges, val_edges_false, logits)
# Print out performance # Print out performance
print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(loss.item()), "train_acc=", print(
"{:.5f}".format(train_acc), "val_roc=", "{:.5f}".format(val_roc), "val_ap=", "{:.5f}".format(val_ap), "Epoch:",
"time=", "{:.5f}".format(time.time() - t)) "%04d" % (epoch + 1),
"train_loss=",
"{:.5f}".format(loss.item()),
"train_acc=",
"{:.5f}".format(train_acc),
"val_roc=",
"{:.5f}".format(val_roc),
"val_ap=",
"{:.5f}".format(val_ap),
"time=",
"{:.5f}".format(time.time() - t),
)
test_roc, test_ap = get_scores(test_edges, test_edges_false, logits) test_roc, test_ap = get_scores(test_edges, test_edges_false, logits)
print("End of training!", "test_roc=", "{:.5f}".format(test_roc), "test_ap=", "{:.5f}".format(test_ap)) print(
"End of training!",
"test_roc=",
"{:.5f}".format(test_roc),
"test_ap=",
"{:.5f}".format(test_ap),
)
# roc_means.append(test_roc) # roc_means.append(test_roc)
# ap_means.append(test_ap) # ap_means.append(test_ap)
...@@ -282,8 +404,8 @@ def web_main(): ...@@ -282,8 +404,8 @@ def web_main():
# print("roc_mean=", "{:.5f}".format(roc_mean), "roc_std=", "{:.5f}".format(roc_std), "ap_mean=", # print("roc_mean=", "{:.5f}".format(roc_mean), "roc_std=", "{:.5f}".format(roc_std), "ap_mean=",
# "{:.5f}".format(ap_mean), "ap_std=", "{:.5f}".format(ap_std)) # "{:.5f}".format(ap_mean), "ap_std=", "{:.5f}".format(ap_std))
if __name__ == '__main__': if __name__ == "__main__":
if args.datasrc == 'dgl': if args.datasrc == "dgl":
dgl_main() dgl_main()
elif args.datasrc == 'website': elif args.datasrc == "website":
web_main() web_main()
import dgl import argparse
import time
import numpy as np import numpy as np
import torch as th import torch as th
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
import tqdm
from torch.utils.data import DataLoader
import dgl
import dgl.function as fn import dgl.function as fn
import dgl.nn.pytorch as dglnn import dgl.nn.pytorch as dglnn
import time
import argparse
import tqdm
from dgl.data import RedditDataset from dgl.data import RedditDataset
from torch.utils.data import DataLoader
class SAGEConvWithCV(nn.Module): class SAGEConvWithCV(nn.Module):
def __init__(self, in_feats, out_feats, activation): def __init__(self, in_feats, out_feats, activation):
...@@ -20,7 +23,7 @@ class SAGEConvWithCV(nn.Module): ...@@ -20,7 +23,7 @@ class SAGEConvWithCV(nn.Module):
self.reset_parameters() self.reset_parameters()
def reset_parameters(self): def reset_parameters(self):
gain = nn.init.calculate_gain('relu') gain = nn.init.calculate_gain("relu")
nn.init.xavier_uniform_(self.W.weight, gain=gain) nn.init.xavier_uniform_(self.W.weight, gain=gain)
nn.init.constant_(self.W.bias, 0) nn.init.constant_(self.W.bias, 0)
...@@ -29,10 +32,14 @@ class SAGEConvWithCV(nn.Module): ...@@ -29,10 +32,14 @@ class SAGEConvWithCV(nn.Module):
with block.local_scope(): with block.local_scope():
H_src, H_dst = H H_src, H_dst = H
HBar_src, agg_HBar_dst = HBar HBar_src, agg_HBar_dst = HBar
block.dstdata['agg_hbar'] = agg_HBar_dst block.dstdata["agg_hbar"] = agg_HBar_dst
block.srcdata['hdelta'] = H_src - HBar_src block.srcdata["hdelta"] = H_src - HBar_src
block.update_all(fn.copy_u('hdelta', 'm'), fn.mean('m', 'hdelta_new')) block.update_all(
h_neigh = block.dstdata['agg_hbar'] + block.dstdata['hdelta_new'] fn.copy_u("hdelta", "m"), fn.mean("m", "hdelta_new")
)
h_neigh = (
block.dstdata["agg_hbar"] + block.dstdata["hdelta_new"]
)
h = self.W(th.cat([H_dst, h_neigh], 1)) h = self.W(th.cat([H_dst, h_neigh], 1))
if self.activation is not None: if self.activation is not None:
h = self.activation(h) h = self.activation(h)
...@@ -40,21 +47,17 @@ class SAGEConvWithCV(nn.Module): ...@@ -40,21 +47,17 @@ class SAGEConvWithCV(nn.Module):
else: else:
with block.local_scope(): with block.local_scope():
H_src, H_dst = H H_src, H_dst = H
block.srcdata['h'] = H_src block.srcdata["h"] = H_src
block.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_new')) block.update_all(fn.copy_u("h", "m"), fn.mean("m", "h_new"))
h_neigh = block.dstdata['h_new'] h_neigh = block.dstdata["h_new"]
h = self.W(th.cat([H_dst, h_neigh], 1)) h = self.W(th.cat([H_dst, h_neigh], 1))
if self.activation is not None: if self.activation is not None:
h = self.activation(h) h = self.activation(h)
return h return h
class SAGE(nn.Module): class SAGE(nn.Module):
def __init__(self, def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation):
in_feats,
n_hidden,
n_classes,
n_layers,
activation):
super().__init__() super().__init__()
self.n_layers = n_layers self.n_layers = n_layers
self.n_hidden = n_hidden self.n_hidden = n_hidden
...@@ -66,20 +69,20 @@ class SAGE(nn.Module): ...@@ -66,20 +69,20 @@ class SAGE(nn.Module):
self.layers.append(SAGEConvWithCV(n_hidden, n_classes, None)) self.layers.append(SAGEConvWithCV(n_hidden, n_classes, None))
def forward(self, blocks): def forward(self, blocks):
h = blocks[0].srcdata['features'] h = blocks[0].srcdata["features"]
updates = [] updates = []
for layer, block in zip(self.layers, blocks): for layer, block in zip(self.layers, blocks):
# We need to first copy the representation of nodes on the RHS from the # We need to first copy the representation of nodes on the RHS from the
# appropriate nodes on the LHS. # appropriate nodes on the LHS.
# Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst # Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst
# would be (num_nodes_RHS, D) # would be (num_nodes_RHS, D)
h_dst = h[:block.number_of_dst_nodes()] h_dst = h[: block.number_of_dst_nodes()]
hbar_src = block.srcdata['hist'] hbar_src = block.srcdata["hist"]
agg_hbar_dst = block.dstdata['agg_hist'] agg_hbar_dst = block.dstdata["agg_hist"]
# Then we compute the updated representation on the RHS. # Then we compute the updated representation on the RHS.
# The shape of h now becomes (num_nodes_RHS, D) # The shape of h now becomes (num_nodes_RHS, D)
h = layer(block, (h, h_dst), (hbar_src, agg_hbar_dst)) h = layer(block, (h, h_dst), (hbar_src, agg_hbar_dst))
block.dstdata['h_new'] = h block.dstdata["h_new"] = h
return h return h
def inference(self, g, x, batch_size, device): def inference(self, g, x, batch_size, device):
...@@ -99,17 +102,22 @@ class SAGE(nn.Module): ...@@ -99,17 +102,22 @@ class SAGE(nn.Module):
nodes = th.arange(g.number_of_nodes()) nodes = th.arange(g.number_of_nodes())
ys = [] ys = []
for l, layer in enumerate(self.layers): for l, layer in enumerate(self.layers):
y = th.zeros(g.number_of_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes) y = th.zeros(
g.number_of_nodes(),
self.n_hidden if l != len(self.layers) - 1 else self.n_classes,
)
for start in tqdm.trange(0, len(nodes), batch_size): for start in tqdm.trange(0, len(nodes), batch_size):
end = start + batch_size end = start + batch_size
batch_nodes = nodes[start:end] batch_nodes = nodes[start:end]
block = dgl.to_block(dgl.in_subgraph(g, batch_nodes), batch_nodes) block = dgl.to_block(
dgl.in_subgraph(g, batch_nodes), batch_nodes
)
block = block.int().to(device) block = block.int().to(device)
induced_nodes = block.srcdata[dgl.NID] induced_nodes = block.srcdata[dgl.NID]
h = x[induced_nodes].to(device) h = x[induced_nodes].to(device)
h_dst = h[:block.number_of_dst_nodes()] h_dst = h[: block.number_of_dst_nodes()]
h = layer(block, (h, h_dst)) h = layer(block, (h, h_dst))
y[start:end] = h.cpu() y[start:end] = h.cpu()
...@@ -119,7 +127,6 @@ class SAGE(nn.Module): ...@@ -119,7 +127,6 @@ class SAGE(nn.Module):
return y, ys return y, ys
class NeighborSampler(object): class NeighborSampler(object):
def __init__(self, g, fanouts): def __init__(self, g, fanouts):
self.g = g self.g = g
...@@ -143,12 +150,14 @@ class NeighborSampler(object): ...@@ -143,12 +150,14 @@ class NeighborSampler(object):
hist_blocks.insert(0, hist_block) hist_blocks.insert(0, hist_block)
return blocks, hist_blocks return blocks, hist_blocks
def compute_acc(pred, labels): def compute_acc(pred, labels):
""" """
Compute the accuracy of prediction given the labels. Compute the accuracy of prediction given the labels.
""" """
return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred) return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)
def evaluate(model, g, labels, val_mask, batch_size, device): def evaluate(model, g, labels, val_mask, batch_size, device):
""" """
Evaluate the model on the validation set specified by ``val_mask``. Evaluate the model on the validation set specified by ``val_mask``.
...@@ -161,54 +170,64 @@ def evaluate(model, g, labels, val_mask, batch_size, device): ...@@ -161,54 +170,64 @@ def evaluate(model, g, labels, val_mask, batch_size, device):
""" """
model.eval() model.eval()
with th.no_grad(): with th.no_grad():
inputs = g.ndata['features'] inputs = g.ndata["features"]
pred, _ = model.inference(g, inputs, batch_size, device) pred, _ = model.inference(g, inputs, batch_size, device)
model.train() model.train()
return compute_acc(pred[val_mask], labels[val_mask]) return compute_acc(pred[val_mask], labels[val_mask])
def load_subtensor(g, labels, blocks, hist_blocks, dev_id, aggregation_on_device=False):
def load_subtensor(
g, labels, blocks, hist_blocks, dev_id, aggregation_on_device=False
):
""" """
Copys features and labels of a set of nodes onto GPU. Copys features and labels of a set of nodes onto GPU.
""" """
blocks[0].srcdata['features'] = g.ndata['features'][blocks[0].srcdata[dgl.NID]] blocks[0].srcdata["features"] = g.ndata["features"][
blocks[-1].dstdata['label'] = labels[blocks[-1].dstdata[dgl.NID]] blocks[0].srcdata[dgl.NID]
]
blocks[-1].dstdata["label"] = labels[blocks[-1].dstdata[dgl.NID]]
ret_blocks = [] ret_blocks = []
ret_hist_blocks = [] ret_hist_blocks = []
for i, (block, hist_block) in enumerate(zip(blocks, hist_blocks)): for i, (block, hist_block) in enumerate(zip(blocks, hist_blocks)):
hist_col = 'features' if i == 0 else 'hist_%d' % i hist_col = "features" if i == 0 else "hist_%d" % i
block.srcdata['hist'] = g.ndata[hist_col][block.srcdata[dgl.NID]] block.srcdata["hist"] = g.ndata[hist_col][block.srcdata[dgl.NID]]
# Aggregate history # Aggregate history
hist_block.srcdata['hist'] = g.ndata[hist_col][hist_block.srcdata[dgl.NID]] hist_block.srcdata["hist"] = g.ndata[hist_col][
hist_block.srcdata[dgl.NID]
]
if aggregation_on_device: if aggregation_on_device:
hist_block = hist_block.to(dev_id) hist_block = hist_block.to(dev_id)
hist_block.update_all(fn.copy_u('hist', 'm'), fn.mean('m', 'agg_hist')) hist_block.update_all(fn.copy_u("hist", "m"), fn.mean("m", "agg_hist"))
block = block.int().to(dev_id) block = block.int().to(dev_id)
if not aggregation_on_device: if not aggregation_on_device:
hist_block = hist_block.to(dev_id) hist_block = hist_block.to(dev_id)
block.dstdata['agg_hist'] = hist_block.dstdata['agg_hist'] block.dstdata["agg_hist"] = hist_block.dstdata["agg_hist"]
ret_blocks.append(block) ret_blocks.append(block)
ret_hist_blocks.append(hist_block) ret_hist_blocks.append(hist_block)
return ret_blocks, ret_hist_blocks return ret_blocks, ret_hist_blocks
def init_history(g, model, dev_id): def init_history(g, model, dev_id):
with th.no_grad(): with th.no_grad():
history = model.inference(g, g.ndata['features'], 1000, dev_id)[1] history = model.inference(g, g.ndata["features"], 1000, dev_id)[1]
for layer in range(args.num_layers + 1): for layer in range(args.num_layers + 1):
if layer > 0: if layer > 0:
hist_col = 'hist_%d' % layer hist_col = "hist_%d" % layer
g.ndata['hist_%d' % layer] = history[layer - 1] g.ndata["hist_%d" % layer] = history[layer - 1]
def update_history(g, blocks): def update_history(g, blocks):
with th.no_grad(): with th.no_grad():
for i, block in enumerate(blocks): for i, block in enumerate(blocks):
ids = block.dstdata[dgl.NID].cpu() ids = block.dstdata[dgl.NID].cpu()
hist_col = 'hist_%d' % (i + 1) hist_col = "hist_%d" % (i + 1)
h_new = block.dstdata['h_new'].cpu() h_new = block.dstdata["h_new"].cpu()
g.ndata[hist_col][ids] = h_new g.ndata[hist_col][ids] = h_new
def run(args, dev_id, data): def run(args, dev_id, data):
dropout = 0.2 dropout = 0.2
...@@ -220,7 +239,7 @@ def run(args, dev_id, data): ...@@ -220,7 +239,7 @@ def run(args, dev_id, data):
val_nid = val_mask.nonzero().squeeze() val_nid = val_mask.nonzero().squeeze()
# Create sampler # Create sampler
sampler = NeighborSampler(g, [int(_) for _ in args.fan_out.split(',')]) sampler = NeighborSampler(g, [int(_) for _ in args.fan_out.split(",")])
# Create PyTorch DataLoader for constructing blocks # Create PyTorch DataLoader for constructing blocks
dataloader = DataLoader( dataloader = DataLoader(
...@@ -229,7 +248,8 @@ def run(args, dev_id, data): ...@@ -229,7 +248,8 @@ def run(args, dev_id, data):
collate_fn=sampler.sample_blocks, collate_fn=sampler.sample_blocks,
shuffle=True, shuffle=True,
drop_last=False, drop_last=False,
num_workers=args.num_workers_per_gpu) num_workers=args.num_workers_per_gpu,
)
# Define model # Define model
model = SAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu) model = SAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu)
...@@ -258,14 +278,16 @@ def run(args, dev_id, data): ...@@ -258,14 +278,16 @@ def run(args, dev_id, data):
input_nodes = blocks[0].srcdata[dgl.NID] input_nodes = blocks[0].srcdata[dgl.NID]
seeds = blocks[-1].dstdata[dgl.NID] seeds = blocks[-1].dstdata[dgl.NID]
blocks, hist_blocks = load_subtensor(g, labels, blocks, hist_blocks, dev_id, True) blocks, hist_blocks = load_subtensor(
g, labels, blocks, hist_blocks, dev_id, True
)
# forward # forward
batch_pred = model(blocks) batch_pred = model(blocks)
# update history # update history
update_history(g, blocks) update_history(g, blocks)
# compute loss # compute loss
batch_labels = blocks[-1].dstdata['label'] batch_labels = blocks[-1].dstdata["label"]
loss = loss_fcn(batch_pred, batch_labels) loss = loss_fcn(batch_pred, batch_labels)
# backward # backward
optimizer.zero_grad() optimizer.zero_grad()
...@@ -274,45 +296,55 @@ def run(args, dev_id, data): ...@@ -274,45 +296,55 @@ def run(args, dev_id, data):
iter_tput.append(len(seeds) / (time.time() - tic_step)) iter_tput.append(len(seeds) / (time.time() - tic_step))
if step % args.log_every == 0: if step % args.log_every == 0:
acc = compute_acc(batch_pred, batch_labels) acc = compute_acc(batch_pred, batch_labels)
print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f}'.format( print(
epoch, step, loss.item(), acc.item(), np.mean(iter_tput[3:]))) "Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f}".format(
epoch,
step,
loss.item(),
acc.item(),
np.mean(iter_tput[3:]),
)
)
tic_step = time.time() tic_step = time.time()
toc = time.time() toc = time.time()
print('Epoch Time(s): {:.4f}'.format(toc - tic)) print("Epoch Time(s): {:.4f}".format(toc - tic))
if epoch >= 5: if epoch >= 5:
avg += toc - tic avg += toc - tic
if epoch % args.eval_every == 0 and epoch != 0: if epoch % args.eval_every == 0 and epoch != 0:
model.eval() model.eval()
eval_acc = evaluate(model, g, labels, val_nid, args.val_batch_size, dev_id) eval_acc = evaluate(
print('Eval Acc {:.4f}'.format(eval_acc)) model, g, labels, val_nid, args.val_batch_size, dev_id
)
print("Eval Acc {:.4f}".format(eval_acc))
print("Avg epoch time: {}".format(avg / (epoch - 4)))
print('Avg epoch time: {}'.format(avg / (epoch - 4)))
if __name__ == '__main__': if __name__ == "__main__":
argparser = argparse.ArgumentParser("multi-gpu training") argparser = argparse.ArgumentParser("multi-gpu training")
argparser.add_argument('--gpu', type=str, default='0') argparser.add_argument("--gpu", type=str, default="0")
argparser.add_argument('--num-epochs', type=int, default=20) argparser.add_argument("--num-epochs", type=int, default=20)
argparser.add_argument('--num-hidden', type=int, default=16) argparser.add_argument("--num-hidden", type=int, default=16)
argparser.add_argument('--num-layers', type=int, default=2) argparser.add_argument("--num-layers", type=int, default=2)
argparser.add_argument('--fan-out', type=str, default='1,1') argparser.add_argument("--fan-out", type=str, default="1,1")
argparser.add_argument('--batch-size', type=int, default=1000) argparser.add_argument("--batch-size", type=int, default=1000)
argparser.add_argument('--val-batch-size', type=int, default=1000) argparser.add_argument("--val-batch-size", type=int, default=1000)
argparser.add_argument('--log-every', type=int, default=20) argparser.add_argument("--log-every", type=int, default=20)
argparser.add_argument('--eval-every', type=int, default=5) argparser.add_argument("--eval-every", type=int, default=5)
argparser.add_argument('--lr', type=float, default=0.003) argparser.add_argument("--lr", type=float, default=0.003)
argparser.add_argument('--num-workers-per-gpu', type=int, default=0) argparser.add_argument("--num-workers-per-gpu", type=int, default=0)
args = argparser.parse_args() args = argparser.parse_args()
# load reddit data # load reddit data
data = RedditDataset(self_loop=True) data = RedditDataset(self_loop=True)
n_classes = data.num_classes n_classes = data.num_classes
g = data[0] g = data[0]
features = g.ndata['feat'] features = g.ndata["feat"]
in_feats = features.shape[1] in_feats = features.shape[1]
labels = g.ndata['label'] labels = g.ndata["label"]
train_mask = g.ndata['train_mask'] train_mask = g.ndata["train_mask"]
val_mask = g.ndata['val_mask'] val_mask = g.ndata["val_mask"]
g.ndata['features'] = features g.ndata["features"] = features
g.create_formats_() g.create_formats_()
# Pack data # Pack data
data = train_mask, val_mask, in_feats, labels, n_classes, g data = train_mask, val_mask, in_feats, labels, n_classes, g
......
import dgl import argparse
import math
import time
import traceback
import numpy as np import numpy as np
import torch as th import torch as th
import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
import torch.multiprocessing as mp import tqdm
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
import dgl
import dgl.function as fn import dgl.function as fn
import dgl.nn.pytorch as dglnn import dgl.nn.pytorch as dglnn
import time
import argparse
import tqdm
import traceback
import math
from dgl.data import RedditDataset from dgl.data import RedditDataset
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel
class SAGEConvWithCV(nn.Module): class SAGEConvWithCV(nn.Module):
def __init__(self, in_feats, out_feats, activation): def __init__(self, in_feats, out_feats, activation):
...@@ -24,7 +27,7 @@ class SAGEConvWithCV(nn.Module): ...@@ -24,7 +27,7 @@ class SAGEConvWithCV(nn.Module):
self.reset_parameters() self.reset_parameters()
def reset_parameters(self): def reset_parameters(self):
gain = nn.init.calculate_gain('relu') gain = nn.init.calculate_gain("relu")
nn.init.xavier_uniform_(self.W.weight, gain=gain) nn.init.xavier_uniform_(self.W.weight, gain=gain)
nn.init.constant_(self.W.bias, 0) nn.init.constant_(self.W.bias, 0)
...@@ -33,10 +36,14 @@ class SAGEConvWithCV(nn.Module): ...@@ -33,10 +36,14 @@ class SAGEConvWithCV(nn.Module):
with block.local_scope(): with block.local_scope():
H_src, H_dst = H H_src, H_dst = H
HBar_src, agg_HBar_dst = HBar HBar_src, agg_HBar_dst = HBar
block.dstdata['agg_hbar'] = agg_HBar_dst block.dstdata["agg_hbar"] = agg_HBar_dst
block.srcdata['hdelta'] = H_src - HBar_src block.srcdata["hdelta"] = H_src - HBar_src
block.update_all(fn.copy_u('hdelta', 'm'), fn.mean('m', 'hdelta_new')) block.update_all(
h_neigh = block.dstdata['agg_hbar'] + block.dstdata['hdelta_new'] fn.copy_u("hdelta", "m"), fn.mean("m", "hdelta_new")
)
h_neigh = (
block.dstdata["agg_hbar"] + block.dstdata["hdelta_new"]
)
h = self.W(th.cat([H_dst, h_neigh], 1)) h = self.W(th.cat([H_dst, h_neigh], 1))
if self.activation is not None: if self.activation is not None:
h = self.activation(h) h = self.activation(h)
...@@ -44,21 +51,17 @@ class SAGEConvWithCV(nn.Module): ...@@ -44,21 +51,17 @@ class SAGEConvWithCV(nn.Module):
else: else:
with block.local_scope(): with block.local_scope():
H_src, H_dst = H H_src, H_dst = H
block.srcdata['h'] = H_src block.srcdata["h"] = H_src
block.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_new')) block.update_all(fn.copy_u("h", "m"), fn.mean("m", "h_new"))
h_neigh = block.dstdata['h_new'] h_neigh = block.dstdata["h_new"]
h = self.W(th.cat([H_dst, h_neigh], 1)) h = self.W(th.cat([H_dst, h_neigh], 1))
if self.activation is not None: if self.activation is not None:
h = self.activation(h) h = self.activation(h)
return h return h
class SAGE(nn.Module): class SAGE(nn.Module):
def __init__(self, def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation):
in_feats,
n_hidden,
n_classes,
n_layers,
activation):
super().__init__() super().__init__()
self.n_layers = n_layers self.n_layers = n_layers
self.n_hidden = n_hidden self.n_hidden = n_hidden
...@@ -70,20 +73,20 @@ class SAGE(nn.Module): ...@@ -70,20 +73,20 @@ class SAGE(nn.Module):
self.layers.append(SAGEConvWithCV(n_hidden, n_classes, None)) self.layers.append(SAGEConvWithCV(n_hidden, n_classes, None))
def forward(self, blocks): def forward(self, blocks):
h = blocks[0].srcdata['features'] h = blocks[0].srcdata["features"]
updates = [] updates = []
for layer, block in zip(self.layers, blocks): for layer, block in zip(self.layers, blocks):
# We need to first copy the representation of nodes on the RHS from the # We need to first copy the representation of nodes on the RHS from the
# appropriate nodes on the LHS. # appropriate nodes on the LHS.
# Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst # Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst
# would be (num_nodes_RHS, D) # would be (num_nodes_RHS, D)
h_dst = h[:block.number_of_dst_nodes()] h_dst = h[: block.number_of_dst_nodes()]
hbar_src = block.srcdata['hist'] hbar_src = block.srcdata["hist"]
agg_hbar_dst = block.dstdata['agg_hist'] agg_hbar_dst = block.dstdata["agg_hist"]
# Then we compute the updated representation on the RHS. # Then we compute the updated representation on the RHS.
# The shape of h now becomes (num_nodes_RHS, D) # The shape of h now becomes (num_nodes_RHS, D)
h = layer(block, (h, h_dst), (hbar_src, agg_hbar_dst)) h = layer(block, (h, h_dst), (hbar_src, agg_hbar_dst))
block.dstdata['h_new'] = h block.dstdata["h_new"] = h
return h return h
def inference(self, g, x, batch_size, device): def inference(self, g, x, batch_size, device):
...@@ -102,17 +105,19 @@ class SAGE(nn.Module): ...@@ -102,17 +105,19 @@ class SAGE(nn.Module):
# TODO: can we standardize this? # TODO: can we standardize this?
nodes = th.arange(g.number_of_nodes()) nodes = th.arange(g.number_of_nodes())
for l, layer in enumerate(self.layers): for l, layer in enumerate(self.layers):
y = g.ndata['hist_%d' % (l + 1)] y = g.ndata["hist_%d" % (l + 1)]
for start in tqdm.trange(0, len(nodes), batch_size): for start in tqdm.trange(0, len(nodes), batch_size):
end = start + batch_size end = start + batch_size
batch_nodes = nodes[start:end] batch_nodes = nodes[start:end]
block = dgl.to_block(dgl.in_subgraph(g, batch_nodes), batch_nodes) block = dgl.to_block(
dgl.in_subgraph(g, batch_nodes), batch_nodes
)
induced_nodes = block.srcdata[dgl.NID] induced_nodes = block.srcdata[dgl.NID]
h = x[induced_nodes].to(device) h = x[induced_nodes].to(device)
block = block.to(device) block = block.to(device)
h_dst = h[:block.number_of_dst_nodes()] h_dst = h[: block.number_of_dst_nodes()]
h = layer(block, (h, h_dst)) h = layer(block, (h, h_dst))
y[start:end] = h.cpu() y[start:end] = h.cpu()
...@@ -121,7 +126,6 @@ class SAGE(nn.Module): ...@@ -121,7 +126,6 @@ class SAGE(nn.Module):
return y return y
class NeighborSampler(object): class NeighborSampler(object):
def __init__(self, g, fanouts): def __init__(self, g, fanouts):
self.g = g self.g = g
...@@ -146,12 +150,14 @@ class NeighborSampler(object): ...@@ -146,12 +150,14 @@ class NeighborSampler(object):
hist_blocks.insert(0, hist_block) hist_blocks.insert(0, hist_block)
return blocks, hist_blocks return blocks, hist_blocks
def compute_acc(pred, labels): def compute_acc(pred, labels):
""" """
Compute the accuracy of prediction given the labels. Compute the accuracy of prediction given the labels.
""" """
return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred) return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)
def evaluate(model, g, labels, val_mask, batch_size, device): def evaluate(model, g, labels, val_mask, batch_size, device):
""" """
Evaluate the model on the validation set specified by ``val_mask``. Evaluate the model on the validation set specified by ``val_mask``.
...@@ -164,69 +170,89 @@ def evaluate(model, g, labels, val_mask, batch_size, device): ...@@ -164,69 +170,89 @@ def evaluate(model, g, labels, val_mask, batch_size, device):
""" """
model.eval() model.eval()
with th.no_grad(): with th.no_grad():
inputs = g.ndata['features'] inputs = g.ndata["features"]
pred = model.inference(g, inputs, batch_size, device) # also recomputes history tensors pred = model.inference(
g, inputs, batch_size, device
) # also recomputes history tensors
model.train() model.train()
return compute_acc(pred[val_mask], labels[val_mask]) return compute_acc(pred[val_mask], labels[val_mask])
def load_subtensor(g, labels, blocks, hist_blocks, dev_id, aggregation_on_device=False):
def load_subtensor(
g, labels, blocks, hist_blocks, dev_id, aggregation_on_device=False
):
""" """
Copys features and labels of a set of nodes onto GPU. Copys features and labels of a set of nodes onto GPU.
""" """
blocks[0].srcdata['features'] = g.ndata['features'][blocks[0].srcdata[dgl.NID]] blocks[0].srcdata["features"] = g.ndata["features"][
blocks[-1].dstdata['label'] = labels[blocks[-1].dstdata[dgl.NID]] blocks[0].srcdata[dgl.NID]
]
blocks[-1].dstdata["label"] = labels[blocks[-1].dstdata[dgl.NID]]
ret_blocks = [] ret_blocks = []
ret_hist_blocks = [] ret_hist_blocks = []
for i, (block, hist_block) in enumerate(zip(blocks, hist_blocks)): for i, (block, hist_block) in enumerate(zip(blocks, hist_blocks)):
hist_col = 'features' if i == 0 else 'hist_%d' % i hist_col = "features" if i == 0 else "hist_%d" % i
block.srcdata['hist'] = g.ndata[hist_col][block.srcdata[dgl.NID]] block.srcdata["hist"] = g.ndata[hist_col][block.srcdata[dgl.NID]]
# Aggregate history # Aggregate history
hist_block.srcdata['hist'] = g.ndata[hist_col][hist_block.srcdata[dgl.NID]] hist_block.srcdata["hist"] = g.ndata[hist_col][
hist_block.srcdata[dgl.NID]
]
if aggregation_on_device: if aggregation_on_device:
hist_block = hist_block.to(dev_id) hist_block = hist_block.to(dev_id)
hist_block.srcdata['hist'] = hist_block.srcdata['hist'] hist_block.srcdata["hist"] = hist_block.srcdata["hist"]
hist_block.update_all(fn.copy_u('hist', 'm'), fn.mean('m', 'agg_hist')) hist_block.update_all(fn.copy_u("hist", "m"), fn.mean("m", "agg_hist"))
block = block.to(dev_id) block = block.to(dev_id)
if not aggregation_on_device: if not aggregation_on_device:
hist_block = hist_block.to(dev_id) hist_block = hist_block.to(dev_id)
block.dstdata['agg_hist'] = hist_block.dstdata['agg_hist'] block.dstdata["agg_hist"] = hist_block.dstdata["agg_hist"]
ret_blocks.append(block) ret_blocks.append(block)
ret_hist_blocks.append(hist_block) ret_hist_blocks.append(hist_block)
return ret_blocks, ret_hist_blocks return ret_blocks, ret_hist_blocks
def create_history_storage(g, args, n_classes): def create_history_storage(g, args, n_classes):
# Initialize history storage # Initialize history storage
for l in range(args.num_layers): for l in range(args.num_layers):
dim = args.num_hidden if l != args.num_layers - 1 else n_classes dim = args.num_hidden if l != args.num_layers - 1 else n_classes
g.ndata['hist_%d' % (l + 1)] = th.zeros(g.number_of_nodes(), dim).share_memory_() g.ndata["hist_%d" % (l + 1)] = th.zeros(
g.number_of_nodes(), dim
).share_memory_()
def init_history(g, model, dev_id, batch_size): def init_history(g, model, dev_id, batch_size):
with th.no_grad(): with th.no_grad():
model.inference(g, g.ndata['features'], batch_size, dev_id) # replaces hist_i features in-place model.inference(
g, g.ndata["features"], batch_size, dev_id
) # replaces hist_i features in-place
def update_history(g, blocks): def update_history(g, blocks):
with th.no_grad(): with th.no_grad():
for i, block in enumerate(blocks): for i, block in enumerate(blocks):
ids = block.dstdata[dgl.NID].cpu() ids = block.dstdata[dgl.NID].cpu()
hist_col = 'hist_%d' % (i + 1) hist_col = "hist_%d" % (i + 1)
h_new = block.dstdata['h_new'].cpu() h_new = block.dstdata["h_new"].cpu()
g.ndata[hist_col][ids] = h_new g.ndata[hist_col][ids] = h_new
def run(proc_id, n_gpus, args, devices, data): def run(proc_id, n_gpus, args, devices, data):
dropout = 0.2 dropout = 0.2
dev_id = devices[proc_id] dev_id = devices[proc_id]
if n_gpus > 1: if n_gpus > 1:
dist_init_method = 'tcp://{master_ip}:{master_port}'.format( dist_init_method = "tcp://{master_ip}:{master_port}".format(
master_ip='127.0.0.1', master_port='12345') master_ip="127.0.0.1", master_port="12345"
)
world_size = n_gpus world_size = n_gpus
th.distributed.init_process_group(backend="nccl", th.distributed.init_process_group(
init_method=dist_init_method, backend="nccl",
world_size=world_size, init_method=dist_init_method,
rank=proc_id) world_size=world_size,
rank=proc_id,
)
th.cuda.set_device(dev_id) th.cuda.set_device(dev_id)
# Unpack data # Unpack data
...@@ -235,17 +261,20 @@ def run(proc_id, n_gpus, args, devices, data): ...@@ -235,17 +261,20 @@ def run(proc_id, n_gpus, args, devices, data):
val_nid = val_mask.nonzero().squeeze() val_nid = val_mask.nonzero().squeeze()
# Create sampler # Create sampler
sampler = NeighborSampler(g, [int(_) for _ in args.fan_out.split(',')]) sampler = NeighborSampler(g, [int(_) for _ in args.fan_out.split(",")])
# Create PyTorch DataLoader for constructing blocks # Create PyTorch DataLoader for constructing blocks
if n_gpus > 1: if n_gpus > 1:
dist_sampler = th.utils.data.distributed.DistributedSampler(train_nid.numpy(), shuffle=True, drop_last=False) dist_sampler = th.utils.data.distributed.DistributedSampler(
train_nid.numpy(), shuffle=True, drop_last=False
)
dataloader = DataLoader( dataloader = DataLoader(
dataset=train_nid.numpy(), dataset=train_nid.numpy(),
batch_size=args.batch_size, batch_size=args.batch_size,
collate_fn=sampler.sample_blocks, collate_fn=sampler.sample_blocks,
sampler=dist_sampler, sampler=dist_sampler,
num_workers=args.num_workers_per_gpu) num_workers=args.num_workers_per_gpu,
)
else: else:
dataloader = DataLoader( dataloader = DataLoader(
dataset=train_nid.numpy(), dataset=train_nid.numpy(),
...@@ -253,7 +282,8 @@ def run(proc_id, n_gpus, args, devices, data): ...@@ -253,7 +282,8 @@ def run(proc_id, n_gpus, args, devices, data):
collate_fn=sampler.sample_blocks, collate_fn=sampler.sample_blocks,
shuffle=True, shuffle=True,
drop_last=False, drop_last=False,
num_workers=args.num_workers_per_gpu) num_workers=args.num_workers_per_gpu,
)
# Define model # Define model
model = SAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu) model = SAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu)
...@@ -261,7 +291,9 @@ def run(proc_id, n_gpus, args, devices, data): ...@@ -261,7 +291,9 @@ def run(proc_id, n_gpus, args, devices, data):
# Move the model to GPU and define optimizer # Move the model to GPU and define optimizer
model = model.to(dev_id) model = model.to(dev_id)
if n_gpus > 1: if n_gpus > 1:
model = DistributedDataParallel(model, device_ids=[dev_id], output_device=dev_id) model = DistributedDataParallel(
model, device_ids=[dev_id], output_device=dev_id
)
loss_fcn = nn.CrossEntropyLoss() loss_fcn = nn.CrossEntropyLoss()
loss_fcn = loss_fcn.to(dev_id) loss_fcn = loss_fcn.to(dev_id)
optimizer = optim.Adam(model.parameters(), lr=args.lr) optimizer = optim.Adam(model.parameters(), lr=args.lr)
...@@ -292,14 +324,16 @@ def run(proc_id, n_gpus, args, devices, data): ...@@ -292,14 +324,16 @@ def run(proc_id, n_gpus, args, devices, data):
# The nodes for output lies at the RHS side of the last block. # The nodes for output lies at the RHS side of the last block.
seeds = blocks[-1].dstdata[dgl.NID] seeds = blocks[-1].dstdata[dgl.NID]
blocks, hist_blocks = load_subtensor(g, labels, blocks, hist_blocks, dev_id, True) blocks, hist_blocks = load_subtensor(
g, labels, blocks, hist_blocks, dev_id, True
)
# forward # forward
batch_pred = model(blocks) batch_pred = model(blocks)
# update history # update history
update_history(g, blocks) update_history(g, blocks)
# compute loss # compute loss
batch_labels = blocks[-1].dstdata['label'] batch_labels = blocks[-1].dstdata["label"]
loss = loss_fcn(batch_pred, batch_labels) loss = loss_fcn(batch_pred, batch_labels)
# backward # backward
optimizer.zero_grad() optimizer.zero_grad()
...@@ -309,56 +343,70 @@ def run(proc_id, n_gpus, args, devices, data): ...@@ -309,56 +343,70 @@ def run(proc_id, n_gpus, args, devices, data):
iter_tput.append(len(seeds) * n_gpus / (time.time() - tic_step)) iter_tput.append(len(seeds) * n_gpus / (time.time() - tic_step))
if step % args.log_every == 0 and proc_id == 0: if step % args.log_every == 0 and proc_id == 0:
acc = compute_acc(batch_pred, batch_labels) acc = compute_acc(batch_pred, batch_labels)
print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f}'.format( print(
epoch, step, loss.item(), acc.item(), np.mean(iter_tput[3:]))) "Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f}".format(
epoch,
step,
loss.item(),
acc.item(),
np.mean(iter_tput[3:]),
)
)
if n_gpus > 1: if n_gpus > 1:
th.distributed.barrier() th.distributed.barrier()
toc = time.time() toc = time.time()
if proc_id == 0: if proc_id == 0:
print('Epoch Time(s): {:.4f}'.format(toc - tic)) print("Epoch Time(s): {:.4f}".format(toc - tic))
if epoch >= 5: if epoch >= 5:
avg += toc - tic avg += toc - tic
if epoch % args.eval_every == 0 and epoch != 0: if epoch % args.eval_every == 0 and epoch != 0:
model.eval() model.eval()
eval_acc = evaluate( eval_acc = evaluate(
model if n_gpus == 1 else model.module, g, labels, val_nid, args.val_batch_size, dev_id) model if n_gpus == 1 else model.module,
print('Eval Acc {:.4f}'.format(eval_acc)) g,
labels,
val_nid,
args.val_batch_size,
dev_id,
)
print("Eval Acc {:.4f}".format(eval_acc))
if n_gpus > 1: if n_gpus > 1:
th.distributed.barrier() th.distributed.barrier()
if proc_id == 0: if proc_id == 0:
print('Avg epoch time: {}'.format(avg / (epoch - 4))) print("Avg epoch time: {}".format(avg / (epoch - 4)))
if __name__ == '__main__':
if __name__ == "__main__":
argparser = argparse.ArgumentParser("multi-gpu training") argparser = argparse.ArgumentParser("multi-gpu training")
argparser.add_argument('--gpu', type=str, default='0') argparser.add_argument("--gpu", type=str, default="0")
argparser.add_argument('--num-epochs', type=int, default=20) argparser.add_argument("--num-epochs", type=int, default=20)
argparser.add_argument('--num-hidden', type=int, default=16) argparser.add_argument("--num-hidden", type=int, default=16)
argparser.add_argument('--num-layers', type=int, default=2) argparser.add_argument("--num-layers", type=int, default=2)
argparser.add_argument('--fan-out', type=str, default='1,1') argparser.add_argument("--fan-out", type=str, default="1,1")
argparser.add_argument('--batch-size', type=int, default=1000) argparser.add_argument("--batch-size", type=int, default=1000)
argparser.add_argument('--val-batch-size', type=int, default=1000) argparser.add_argument("--val-batch-size", type=int, default=1000)
argparser.add_argument('--log-every', type=int, default=20) argparser.add_argument("--log-every", type=int, default=20)
argparser.add_argument('--eval-every', type=int, default=5) argparser.add_argument("--eval-every", type=int, default=5)
argparser.add_argument('--lr', type=float, default=0.003) argparser.add_argument("--lr", type=float, default=0.003)
argparser.add_argument('--num-workers-per-gpu', type=int, default=0) argparser.add_argument("--num-workers-per-gpu", type=int, default=0)
args = argparser.parse_args() args = argparser.parse_args()
devices = list(map(int, args.gpu.split(','))) devices = list(map(int, args.gpu.split(",")))
n_gpus = len(devices) n_gpus = len(devices)
# load reddit data # load reddit data
data = RedditDataset(self_loop=True) data = RedditDataset(self_loop=True)
n_classes = data.num_classes n_classes = data.num_classes
g = data[0] g = data[0]
features = g.ndata['feat'] features = g.ndata["feat"]
in_feats = features.shape[1] in_feats = features.shape[1]
labels = g.ndata['label'] labels = g.ndata["label"]
train_mask = g.ndata['train_mask'] train_mask = g.ndata["train_mask"]
val_mask = g.ndata['val_mask'] val_mask = g.ndata["val_mask"]
g.ndata['features'] = features.share_memory_() g.ndata["features"] = features.share_memory_()
create_history_storage(g, args, n_classes) create_history_storage(g, args, n_classes)
# Create csr/coo/csc formats before launching training processes with multi-gpu. # Create csr/coo/csc formats before launching training processes with multi-gpu.
......
...@@ -7,19 +7,21 @@ Papers: https://arxiv.org/abs/1809.10341 ...@@ -7,19 +7,21 @@ Papers: https://arxiv.org/abs/1809.10341
Author's code: https://github.com/PetarV-/DGI Author's code: https://github.com/PetarV-/DGI
""" """
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import math import math
import numpy as np
import tensorflow as tf
from gcn import GCN from gcn import GCN
from tensorflow.keras import layers
class Encoder(layers.Layer): class Encoder(layers.Layer):
def __init__(self, g, in_feats, n_hidden, n_layers, activation, dropout): def __init__(self, g, in_feats, n_hidden, n_layers, activation, dropout):
super(Encoder, self).__init__() super(Encoder, self).__init__()
self.g = g self.g = g
self.conv = GCN(g, in_feats, n_hidden, n_hidden, self.conv = GCN(
n_layers, activation, dropout) g, in_feats, n_hidden, n_hidden, n_layers, activation, dropout
)
def call(self, features, corrupt=False): def call(self, features, corrupt=False):
if corrupt: if corrupt:
...@@ -33,21 +35,26 @@ class Discriminator(layers.Layer): ...@@ -33,21 +35,26 @@ class Discriminator(layers.Layer):
def __init__(self, n_hidden): def __init__(self, n_hidden):
super(Discriminator, self).__init__() super(Discriminator, self).__init__()
uinit = tf.keras.initializers.RandomUniform( uinit = tf.keras.initializers.RandomUniform(
-1.0/math.sqrt(n_hidden), 1.0/math.sqrt(n_hidden)) -1.0 / math.sqrt(n_hidden), 1.0 / math.sqrt(n_hidden)
self.weight = tf.Variable(initial_value=uinit( )
shape=(n_hidden, n_hidden), dtype='float32'), trainable=True) self.weight = tf.Variable(
initial_value=uinit(shape=(n_hidden, n_hidden), dtype="float32"),
trainable=True,
)
def call(self, features, summary): def call(self, features, summary):
features = tf.matmul(features, tf.matmul( features = tf.matmul(
self.weight, tf.expand_dims(summary, -1))) features, tf.matmul(self.weight, tf.expand_dims(summary, -1))
)
return features return features
class DGI(tf.keras.Model): class DGI(tf.keras.Model):
def __init__(self, g, in_feats, n_hidden, n_layers, activation, dropout): def __init__(self, g, in_feats, n_hidden, n_layers, activation, dropout):
super(DGI, self).__init__() super(DGI, self).__init__()
self.encoder = Encoder(g, in_feats, n_hidden, self.encoder = Encoder(
n_layers, activation, dropout) g, in_feats, n_hidden, n_layers, activation, dropout
)
self.discriminator = Discriminator(n_hidden) self.discriminator = Discriminator(n_hidden)
self.loss = tf.nn.sigmoid_cross_entropy_with_logits self.loss = tf.nn.sigmoid_cross_entropy_with_logits
...@@ -59,7 +66,7 @@ class DGI(tf.keras.Model): ...@@ -59,7 +66,7 @@ class DGI(tf.keras.Model):
positive = self.discriminator(positive, summary) positive = self.discriminator(positive, summary)
negative = self.discriminator(negative, summary) negative = self.discriminator(negative, summary)
l1 = self.loss(tf.ones(positive.shape),positive) l1 = self.loss(tf.ones(positive.shape), positive)
l2 = self.loss(tf.zeros(negative.shape), negative) l2 = self.loss(tf.zeros(negative.shape), negative)
return tf.reduce_mean(l1) + tf.reduce_mean(l2) return tf.reduce_mean(l1) + tf.reduce_mean(l2)
......
...@@ -6,23 +6,21 @@ from tensorflow.keras import layers ...@@ -6,23 +6,21 @@ from tensorflow.keras import layers
from dgl.nn.tensorflow import GraphConv from dgl.nn.tensorflow import GraphConv
class GCN(layers.Layer): class GCN(layers.Layer):
def __init__(self, def __init__(
g, self, g, in_feats, n_hidden, n_classes, n_layers, activation, dropout
in_feats, ):
n_hidden,
n_classes,
n_layers,
activation,
dropout):
super(GCN, self).__init__() super(GCN, self).__init__()
self.g = g self.g = g
self.layers =[] self.layers = []
# input layer # input layer
self.layers.append(GraphConv(in_feats, n_hidden, activation=activation)) self.layers.append(GraphConv(in_feats, n_hidden, activation=activation))
# hidden layers # hidden layers
for i in range(n_layers - 1): for i in range(n_layers - 1):
self.layers.append(GraphConv(n_hidden, n_hidden, activation=activation)) self.layers.append(
GraphConv(n_hidden, n_hidden, activation=activation)
)
# output layer # output layer
self.layers.append(GraphConv(n_hidden, n_classes)) self.layers.append(GraphConv(n_hidden, n_classes))
self.dropout = layers.Dropout(dropout) self.dropout = layers.Dropout(dropout)
......
import argparse import argparse
import time import time
import numpy as np
import networkx as nx import networkx as nx
import numpy as np
import tensorflow as tf import tensorflow as tf
from dgi import DGI, Classifier
from tensorflow.keras import layers from tensorflow.keras import layers
import dgl import dgl
from dgl.data import register_data_args from dgl.data import (
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset CiteseerGraphDataset,
from dgi import DGI, Classifier CoraGraphDataset,
PubmedGraphDataset,
register_data_args,
)
def evaluate(model, features, labels, mask): def evaluate(model, features, labels, mask):
...@@ -21,14 +27,14 @@ def evaluate(model, features, labels, mask): ...@@ -21,14 +27,14 @@ def evaluate(model, features, labels, mask):
def main(args): def main(args):
# load and preprocess dataset # load and preprocess dataset
if args.dataset == 'cora': if args.dataset == "cora":
data = CoraGraphDataset() data = CoraGraphDataset()
elif args.dataset == 'citeseer': elif args.dataset == "citeseer":
data = CiteseerGraphDataset() data = CiteseerGraphDataset()
elif args.dataset == 'pubmed': elif args.dataset == "pubmed":
data = PubmedGraphDataset() data = PubmedGraphDataset()
else: else:
raise ValueError('Unknown dataset: {}'.format(args.dataset)) raise ValueError("Unknown dataset: {}".format(args.dataset))
g = data[0] g = data[0]
if args.gpu < 0: if args.gpu < 0:
...@@ -38,11 +44,11 @@ def main(args): ...@@ -38,11 +44,11 @@ def main(args):
g = g.to(device) g = g.to(device)
with tf.device(device): with tf.device(device):
features = g.ndata['feat'] features = g.ndata["feat"]
labels = g.ndata['label'] labels = g.ndata["label"]
train_mask = g.ndata['train_mask'] train_mask = g.ndata["train_mask"]
val_mask = g.ndata['val_mask'] val_mask = g.ndata["val_mask"]
test_mask = g.ndata['test_mask'] test_mask = g.ndata["test_mask"]
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_labels
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
...@@ -54,15 +60,18 @@ def main(args): ...@@ -54,15 +60,18 @@ def main(args):
n_edges = g.number_of_edges() n_edges = g.number_of_edges()
# create DGI model # create DGI model
dgi = DGI(g, dgi = DGI(
in_feats, g,
args.n_hidden, in_feats,
args.n_layers, args.n_hidden,
tf.keras.layers.PReLU(alpha_initializer=tf.constant_initializer(0.25)), args.n_layers,
args.dropout) tf.keras.layers.PReLU(
alpha_initializer=tf.constant_initializer(0.25)
dgi_optimizer = tf.keras.optimizers.Adam( ),
learning_rate=args.dgi_lr) args.dropout,
)
dgi_optimizer = tf.keras.optimizers.Adam(learning_rate=args.dgi_lr)
# train deep graph infomax # train deep graph infomax
cnt_wait = 0 cnt_wait = 0
...@@ -80,8 +89,7 @@ def main(args): ...@@ -80,8 +89,7 @@ def main(args):
# of Adam(W) optimizer with PyTorch. And this results in worse results. # of Adam(W) optimizer with PyTorch. And this results in worse results.
# Manually adding weights to the loss to do weight decay solves this problem. # Manually adding weights to the loss to do weight decay solves this problem.
for weight in dgi.trainable_weights: for weight in dgi.trainable_weights:
loss = loss + \ loss = loss + args.weight_decay * tf.nn.l2_loss(weight)
args.weight_decay * tf.nn.l2_loss(weight)
grads = tape.gradient(loss, dgi.trainable_weights) grads = tape.gradient(loss, dgi.trainable_weights)
dgi_optimizer.apply_gradients(zip(grads, dgi.trainable_weights)) dgi_optimizer.apply_gradients(zip(grads, dgi.trainable_weights))
...@@ -89,34 +97,43 @@ def main(args): ...@@ -89,34 +97,43 @@ def main(args):
best = loss best = loss
best_t = epoch best_t = epoch
cnt_wait = 0 cnt_wait = 0
dgi.save_weights('best_dgi.pkl') dgi.save_weights("best_dgi.pkl")
else: else:
cnt_wait += 1 cnt_wait += 1
if cnt_wait == args.patience: if cnt_wait == args.patience:
print('Early stopping!') print("Early stopping!")
break break
if epoch >= 3: if epoch >= 3:
dur.append(time.time() - t0) dur.append(time.time() - t0)
print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | " print(
"ETputs(KTEPS) {:.2f}".format(epoch, np.mean(dur), loss.numpy().item(), "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | "
n_edges / np.mean(dur) / 1000)) "ETputs(KTEPS) {:.2f}".format(
epoch,
np.mean(dur),
loss.numpy().item(),
n_edges / np.mean(dur) / 1000,
)
)
# create classifier model # create classifier model
classifier = Classifier(args.n_hidden, n_classes) classifier = Classifier(args.n_hidden, n_classes)
classifier_optimizer = tf.keras.optimizers.Adam(learning_rate=args.classifier_lr) classifier_optimizer = tf.keras.optimizers.Adam(
learning_rate=args.classifier_lr
)
# train classifier # train classifier
print('Loading {}th epoch'.format(best_t)) print("Loading {}th epoch".format(best_t))
dgi.load_weights('best_dgi.pkl') dgi.load_weights("best_dgi.pkl")
embeds = dgi.encoder(features, corrupt=False) embeds = dgi.encoder(features, corrupt=False)
embeds = tf.stop_gradient(embeds) embeds = tf.stop_gradient(embeds)
dur = [] dur = []
loss_fcn = tf.keras.losses.SparseCategoricalCrossentropy( loss_fcn = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True) from_logits=True
)
for epoch in range(args.n_classifier_epochs): for epoch in range(args.n_classifier_epochs):
if epoch >= 3: if epoch >= 3:
t0 = time.time() t0 = time.time()
...@@ -133,45 +150,74 @@ def main(args): ...@@ -133,45 +150,74 @@ def main(args):
# loss = loss + \ # loss = loss + \
# args.weight_decay * tf.nn.l2_loss(weight) # args.weight_decay * tf.nn.l2_loss(weight)
grads = tape.gradient(loss, classifier.trainable_weights) grads = tape.gradient(loss, classifier.trainable_weights)
classifier_optimizer.apply_gradients(zip(grads, classifier.trainable_weights)) classifier_optimizer.apply_gradients(
zip(grads, classifier.trainable_weights)
)
if epoch >= 3: if epoch >= 3:
dur.append(time.time() - t0) dur.append(time.time() - t0)
acc = evaluate(classifier, embeds, labels, val_mask) acc = evaluate(classifier, embeds, labels, val_mask)
print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | " print(
"ETputs(KTEPS) {:.2f}".format(epoch, np.mean(dur), loss.numpy().item(), "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
acc, n_edges / np.mean(dur) / 1000)) "ETputs(KTEPS) {:.2f}".format(
epoch,
np.mean(dur),
loss.numpy().item(),
acc,
n_edges / np.mean(dur) / 1000,
)
)
print() print()
acc = evaluate(classifier, embeds, labels, test_mask) acc = evaluate(classifier, embeds, labels, test_mask)
print("Test Accuracy {:.4f}".format(acc)) print("Test Accuracy {:.4f}".format(acc))
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser(description='DGI') parser = argparse.ArgumentParser(description="DGI")
register_data_args(parser) register_data_args(parser)
parser.add_argument("--dropout", type=float, default=0., parser.add_argument(
help="dropout probability") "--dropout", type=float, default=0.0, help="dropout probability"
parser.add_argument("--gpu", type=int, default=-1, )
help="gpu") parser.add_argument("--gpu", type=int, default=-1, help="gpu")
parser.add_argument("--dgi-lr", type=float, default=1e-3, parser.add_argument(
help="dgi learning rate") "--dgi-lr", type=float, default=1e-3, help="dgi learning rate"
parser.add_argument("--classifier-lr", type=float, default=1e-2, )
help="classifier learning rate") parser.add_argument(
parser.add_argument("--n-dgi-epochs", type=int, default=300, "--classifier-lr",
help="number of training epochs") type=float,
parser.add_argument("--n-classifier-epochs", type=int, default=300, default=1e-2,
help="number of training epochs") help="classifier learning rate",
parser.add_argument("--n-hidden", type=int, default=512, )
help="number of hidden gcn units") parser.add_argument(
parser.add_argument("--n-layers", type=int, default=1, "--n-dgi-epochs",
help="number of hidden gcn layers") type=int,
parser.add_argument("--weight-decay", type=float, default=0., default=300,
help="Weight for L2 loss") help="number of training epochs",
parser.add_argument("--patience", type=int, default=20, )
help="early stop patience condition") parser.add_argument(
parser.add_argument("--self-loop", action='store_true', "--n-classifier-epochs",
help="graph self-loop (default=False)") type=int,
default=300,
help="number of training epochs",
)
parser.add_argument(
"--n-hidden", type=int, default=512, help="number of hidden gcn units"
)
parser.add_argument(
"--n-layers", type=int, default=1, help="number of hidden gcn layers"
)
parser.add_argument(
"--weight-decay", type=float, default=0.0, help="Weight for L2 loss"
)
parser.add_argument(
"--patience", type=int, default=20, help="early stop patience condition"
)
parser.add_argument(
"--self-loop",
action="store_true",
help="graph self-loop (default=False)",
)
parser.set_defaults(self_loop=False) parser.set_defaults(self_loop=False)
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
......
...@@ -9,42 +9,72 @@ Pytorch implementation: https://github.com/Diego999/pyGAT ...@@ -9,42 +9,72 @@ Pytorch implementation: https://github.com/Diego999/pyGAT
import tensorflow as tf import tensorflow as tf
from tensorflow.keras import layers from tensorflow.keras import layers
import dgl.function as fn import dgl.function as fn
from dgl.nn import GATConv from dgl.nn import GATConv
class GAT(tf.keras.Model): class GAT(tf.keras.Model):
def __init__(self, def __init__(
g, self,
num_layers, g,
in_dim, num_layers,
num_hidden, in_dim,
num_classes, num_hidden,
heads, num_classes,
activation, heads,
feat_drop, activation,
attn_drop, feat_drop,
negative_slope, attn_drop,
residual): negative_slope,
residual,
):
super(GAT, self).__init__() super(GAT, self).__init__()
self.g = g self.g = g
self.num_layers = num_layers self.num_layers = num_layers
self.gat_layers = [] self.gat_layers = []
self.activation = activation self.activation = activation
# input projection (no residual) # input projection (no residual)
self.gat_layers.append(GATConv( self.gat_layers.append(
in_dim, num_hidden, heads[0], GATConv(
feat_drop, attn_drop, negative_slope, False, self.activation)) in_dim,
num_hidden,
heads[0],
feat_drop,
attn_drop,
negative_slope,
False,
self.activation,
)
)
# hidden layers # hidden layers
for l in range(1, num_layers): for l in range(1, num_layers):
# due to multi-head, the in_dim = num_hidden * num_heads # due to multi-head, the in_dim = num_hidden * num_heads
self.gat_layers.append(GATConv( self.gat_layers.append(
num_hidden * heads[l-1], num_hidden, heads[l], GATConv(
feat_drop, attn_drop, negative_slope, residual, self.activation)) num_hidden * heads[l - 1],
num_hidden,
heads[l],
feat_drop,
attn_drop,
negative_slope,
residual,
self.activation,
)
)
# output projection # output projection
self.gat_layers.append(GATConv( self.gat_layers.append(
num_hidden * heads[-2], num_classes, heads[-1], GATConv(
feat_drop, attn_drop, negative_slope, residual, None)) num_hidden * heads[-2],
num_classes,
heads[-1],
feat_drop,
attn_drop,
negative_slope,
residual,
None,
)
)
def call(self, inputs): def call(self, inputs):
h = inputs h = inputs
......
...@@ -11,16 +11,23 @@ Pytorch implementation: https://github.com/Diego999/pyGAT ...@@ -11,16 +11,23 @@ Pytorch implementation: https://github.com/Diego999/pyGAT
""" """
import argparse import argparse
import numpy as np
import networkx as nx
import time import time
import networkx as nx
import numpy as np
import tensorflow as tf import tensorflow as tf
import dgl
from dgl.data import register_data_args
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
from gat import GAT from gat import GAT
from utils import EarlyStopping from utils import EarlyStopping
import dgl
from dgl.data import (
CiteseerGraphDataset,
CoraGraphDataset,
PubmedGraphDataset,
register_data_args,
)
def accuracy(logits, labels): def accuracy(logits, labels):
indices = tf.math.argmax(logits, axis=1) indices = tf.math.argmax(logits, axis=1)
acc = tf.reduce_mean(tf.cast(indices == labels, dtype=tf.float32)) acc = tf.reduce_mean(tf.cast(indices == labels, dtype=tf.float32))
...@@ -36,14 +43,14 @@ def evaluate(model, features, labels, mask): ...@@ -36,14 +43,14 @@ def evaluate(model, features, labels, mask):
def main(args): def main(args):
# load and preprocess dataset # load and preprocess dataset
if args.dataset == 'cora': if args.dataset == "cora":
data = CoraGraphDataset() data = CoraGraphDataset()
elif args.dataset == 'citeseer': elif args.dataset == "citeseer":
data = CiteseerGraphDataset() data = CiteseerGraphDataset()
elif args.dataset == 'pubmed': elif args.dataset == "pubmed":
data = PubmedGraphDataset() data = PubmedGraphDataset()
else: else:
raise ValueError('Unknown dataset: {}'.format(args.dataset)) raise ValueError("Unknown dataset: {}".format(args.dataset))
g = data[0] g = data[0]
if args.gpu < 0: if args.gpu < 0:
...@@ -53,41 +60,48 @@ def main(args): ...@@ -53,41 +60,48 @@ def main(args):
g = g.to(device) g = g.to(device)
with tf.device(device): with tf.device(device):
features = g.ndata['feat'] features = g.ndata["feat"]
labels = g.ndata['label'] labels = g.ndata["label"]
train_mask = g.ndata['train_mask'] train_mask = g.ndata["train_mask"]
val_mask = g.ndata['val_mask'] val_mask = g.ndata["val_mask"]
test_mask = g.ndata['test_mask'] test_mask = g.ndata["test_mask"]
num_feats = features.shape[1] num_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_labels
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
print("""----Data statistics------' print(
"""----Data statistics------'
#Edges %d #Edges %d
#Classes %d #Classes %d
#Train samples %d #Train samples %d
#Val samples %d #Val samples %d
#Test samples %d""" % #Test samples %d"""
(n_edges, n_classes, % (
train_mask.numpy().sum(), n_edges,
val_mask.numpy().sum(), n_classes,
test_mask.numpy().sum())) train_mask.numpy().sum(),
val_mask.numpy().sum(),
test_mask.numpy().sum(),
)
)
g = dgl.remove_self_loop(g) g = dgl.remove_self_loop(g)
g = dgl.add_self_loop(g) g = dgl.add_self_loop(g)
n_edges = g.number_of_edges() n_edges = g.number_of_edges()
# create model # create model
heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads] heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads]
model = GAT(g, model = GAT(
args.num_layers, g,
num_feats, args.num_layers,
args.num_hidden, num_feats,
n_classes, args.num_hidden,
heads, n_classes,
tf.nn.elu, heads,
args.in_drop, tf.nn.elu,
args.attn_drop, args.in_drop,
args.negative_slope, args.attn_drop,
args.residual) args.negative_slope,
args.residual,
)
print(model) print(model)
if args.early_stop: if args.early_stop:
stopper = EarlyStopping(patience=100) stopper = EarlyStopping(patience=100)
...@@ -98,7 +112,8 @@ def main(args): ...@@ -98,7 +112,8 @@ def main(args):
# use optimizer # use optimizer
optimizer = tf.keras.optimizers.Adam( optimizer = tf.keras.optimizers.Adam(
learning_rate=args.lr, epsilon=1e-8) learning_rate=args.lr, epsilon=1e-8
)
# initialize graph # initialize graph
dur = [] dur = []
...@@ -109,15 +124,19 @@ def main(args): ...@@ -109,15 +124,19 @@ def main(args):
with tf.GradientTape() as tape: with tf.GradientTape() as tape:
tape.watch(model.trainable_weights) tape.watch(model.trainable_weights)
logits = model(features, training=True) logits = model(features, training=True)
loss_value = tf.reduce_mean(loss_fcn( loss_value = tf.reduce_mean(
labels=labels[train_mask], logits=logits[train_mask])) loss_fcn(
labels=labels[train_mask], logits=logits[train_mask]
)
)
# Manually Weight Decay # Manually Weight Decay
# We found Tensorflow has a different implementation on weight decay # We found Tensorflow has a different implementation on weight decay
# of Adam(W) optimizer with PyTorch. And this results in worse results. # of Adam(W) optimizer with PyTorch. And this results in worse results.
# Manually adding weights to the loss to do weight decay solves this problem. # Manually adding weights to the loss to do weight decay solves this problem.
for weight in model.trainable_weights: for weight in model.trainable_weights:
loss_value = loss_value + \ loss_value = loss_value + args.weight_decay * tf.nn.l2_loss(
args.weight_decay*tf.nn.l2_loss(weight) weight
)
grads = tape.gradient(loss_value, model.trainable_weights) grads = tape.gradient(loss_value, model.trainable_weights)
optimizer.apply_gradients(zip(grads, model.trainable_weights)) optimizer.apply_gradients(zip(grads, model.trainable_weights))
...@@ -135,50 +154,90 @@ def main(args): ...@@ -135,50 +154,90 @@ def main(args):
if stopper.step(val_acc, model): if stopper.step(val_acc, model):
break break
print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | TrainAcc {:.4f} |" print(
" ValAcc {:.4f} | ETputs(KTEPS) {:.2f}". "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | TrainAcc {:.4f} |"
format(epoch, np.mean(dur), loss_value.numpy().item(), train_acc, " ValAcc {:.4f} | ETputs(KTEPS) {:.2f}".format(
val_acc, n_edges / np.mean(dur) / 1000)) epoch,
np.mean(dur),
loss_value.numpy().item(),
train_acc,
val_acc,
n_edges / np.mean(dur) / 1000,
)
)
print() print()
if args.early_stop: if args.early_stop:
model.load_weights('es_checkpoint.pb') model.load_weights("es_checkpoint.pb")
acc = evaluate(model, features, labels, test_mask) acc = evaluate(model, features, labels, test_mask)
print("Test Accuracy {:.4f}".format(acc)) print("Test Accuracy {:.4f}".format(acc))
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser(description='GAT') parser = argparse.ArgumentParser(description="GAT")
register_data_args(parser) register_data_args(parser)
parser.add_argument("--gpu", type=int, default=-1, parser.add_argument(
help="which GPU to use. Set -1 to use CPU.") "--gpu",
parser.add_argument("--epochs", type=int, default=200, type=int,
help="number of training epochs") default=-1,
parser.add_argument("--num-heads", type=int, default=8, help="which GPU to use. Set -1 to use CPU.",
help="number of hidden attention heads") )
parser.add_argument("--num-out-heads", type=int, default=1, parser.add_argument(
help="number of output attention heads") "--epochs", type=int, default=200, help="number of training epochs"
parser.add_argument("--num-layers", type=int, default=1, )
help="number of hidden layers") parser.add_argument(
parser.add_argument("--num-hidden", type=int, default=8, "--num-heads",
help="number of hidden units") type=int,
parser.add_argument("--residual", action="store_true", default=False, default=8,
help="use residual connection") help="number of hidden attention heads",
parser.add_argument("--in-drop", type=float, default=.6, )
help="input feature dropout") parser.add_argument(
parser.add_argument("--attn-drop", type=float, default=.6, "--num-out-heads",
help="attention dropout") type=int,
parser.add_argument("--lr", type=float, default=0.005, default=1,
help="learning rate") help="number of output attention heads",
parser.add_argument('--weight-decay', type=float, default=5e-4, )
help="weight decay") parser.add_argument(
parser.add_argument('--negative-slope', type=float, default=0.2, "--num-layers", type=int, default=1, help="number of hidden layers"
help="the negative slope of leaky relu") )
parser.add_argument('--early-stop', action='store_true', default=False, parser.add_argument(
help="indicates whether to use early stop or not") "--num-hidden", type=int, default=8, help="number of hidden units"
parser.add_argument('--fastmode', action="store_true", default=False, )
help="skip re-evaluate the validation set") parser.add_argument(
"--residual",
action="store_true",
default=False,
help="use residual connection",
)
parser.add_argument(
"--in-drop", type=float, default=0.6, help="input feature dropout"
)
parser.add_argument(
"--attn-drop", type=float, default=0.6, help="attention dropout"
)
parser.add_argument("--lr", type=float, default=0.005, help="learning rate")
parser.add_argument(
"--weight-decay", type=float, default=5e-4, help="weight decay"
)
parser.add_argument(
"--negative-slope",
type=float,
default=0.2,
help="the negative slope of leaky relu",
)
parser.add_argument(
"--early-stop",
action="store_true",
default=False,
help="indicates whether to use early stop or not",
)
parser.add_argument(
"--fastmode",
action="store_true",
default=False,
help="skip re-evaluate the validation set",
)
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
......
import numpy as np import numpy as np
class EarlyStopping: class EarlyStopping:
def __init__(self, patience=10): def __init__(self, patience=10):
self.patience = patience self.patience = patience
...@@ -14,7 +15,9 @@ class EarlyStopping: ...@@ -14,7 +15,9 @@ class EarlyStopping:
self.save_checkpoint(model) self.save_checkpoint(model)
elif score < self.best_score: elif score < self.best_score:
self.counter += 1 self.counter += 1
print(f'EarlyStopping counter: {self.counter} out of {self.patience}') print(
f"EarlyStopping counter: {self.counter} out of {self.patience}"
)
if self.counter >= self.patience: if self.counter >= self.patience:
self.early_stop = True self.early_stop = True
else: else:
...@@ -24,5 +27,5 @@ class EarlyStopping: ...@@ -24,5 +27,5 @@ class EarlyStopping:
return self.early_stop return self.early_stop
def save_checkpoint(self, model): def save_checkpoint(self, model):
'''Saves model when validation loss decrease.''' """Saves model when validation loss decrease."""
model.save_weights('es_checkpoint.pb') model.save_weights("es_checkpoint.pb")
...@@ -7,25 +7,26 @@ References: ...@@ -7,25 +7,26 @@ References:
""" """
import tensorflow as tf import tensorflow as tf
from tensorflow.keras import layers from tensorflow.keras import layers
from dgl.nn.tensorflow import GraphConv from dgl.nn.tensorflow import GraphConv
class GCN(tf.keras.Model): class GCN(tf.keras.Model):
def __init__(self, def __init__(
g, self, g, in_feats, n_hidden, n_classes, n_layers, activation, dropout
in_feats, ):
n_hidden,
n_classes,
n_layers,
activation,
dropout):
super(GCN, self).__init__() super(GCN, self).__init__()
self.g = g self.g = g
self.layer_list = [] self.layer_list = []
# input layer # input layer
self.layer_list.append(GraphConv(in_feats, n_hidden, activation=activation)) self.layer_list.append(
GraphConv(in_feats, n_hidden, activation=activation)
)
# hidden layers # hidden layers
for i in range(n_layers - 1): for i in range(n_layers - 1):
self.layer_list.append(GraphConv(n_hidden, n_hidden, activation=activation)) self.layer_list.append(
GraphConv(n_hidden, n_hidden, activation=activation)
)
# output layer # output layer
self.layer_list.append(GraphConv(n_hidden, n_classes)) self.layer_list.append(GraphConv(n_hidden, n_classes))
self.dropout = layers.Dropout(dropout) self.dropout = layers.Dropout(dropout)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment