Unverified Commit 1de192f4 authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[KG] More fixes on eval (#1168)

* remove parallel sampling for multiprocessing.

* avoid memory copy in eval.

* remove print.
parent 346bc235
...@@ -286,7 +286,6 @@ class EvalDataset(object): ...@@ -286,7 +286,6 @@ class EvalDataset(object):
beg = edges.shape[0] * rank // ranks beg = edges.shape[0] * rank // ranks
end = min(edges.shape[0] * (rank + 1) // ranks, edges.shape[0]) end = min(edges.shape[0] * (rank + 1) // ranks, edges.shape[0])
edges = edges[beg: end] edges = edges[beg: end]
print("eval on {} edges".format(len(edges)))
return EvalSampler(self.g, edges, batch_size, neg_sample_size, return EvalSampler(self.g, edges, batch_size, neg_sample_size,
mode, num_workers, filter_false_neg) mode, num_workers, filter_false_neg)
......
...@@ -102,6 +102,11 @@ def main(args): ...@@ -102,6 +102,11 @@ def main(args):
if args.neg_sample_size < 0: if args.neg_sample_size < 0:
args.neg_sample_size_test = args.neg_sample_size = eval_dataset.g.number_of_nodes() args.neg_sample_size_test = args.neg_sample_size = eval_dataset.g.number_of_nodes()
args.eval_filter = not args.no_eval_filter args.eval_filter = not args.no_eval_filter
num_workers = args.num_worker
# for multiprocessing evaluation, we don't need to sample multiple batches at a time
# in each process.
if args.num_proc > 1:
num_workers = 1
if args.num_proc > 1: if args.num_proc > 1:
test_sampler_tails = [] test_sampler_tails = []
test_sampler_heads = [] test_sampler_heads = []
...@@ -110,13 +115,13 @@ def main(args): ...@@ -110,13 +115,13 @@ def main(args):
args.neg_sample_size, args.neg_sample_size,
args.eval_filter, args.eval_filter,
mode='PBG-head', mode='PBG-head',
num_workers=args.num_worker, num_workers=num_workers,
rank=i, ranks=args.num_proc) rank=i, ranks=args.num_proc)
test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size, test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size,
args.neg_sample_size, args.neg_sample_size,
args.eval_filter, args.eval_filter,
mode='PBG-tail', mode='PBG-tail',
num_workers=args.num_worker, num_workers=num_workers,
rank=i, ranks=args.num_proc) rank=i, ranks=args.num_proc)
test_sampler_heads.append(test_sampler_head) test_sampler_heads.append(test_sampler_head)
test_sampler_tails.append(test_sampler_tail) test_sampler_tails.append(test_sampler_tail)
...@@ -125,13 +130,13 @@ def main(args): ...@@ -125,13 +130,13 @@ def main(args):
args.neg_sample_size, args.neg_sample_size,
args.eval_filter, args.eval_filter,
mode='PBG-head', mode='PBG-head',
num_workers=args.num_worker, num_workers=num_workers,
rank=0, ranks=1) rank=0, ranks=1)
test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size, test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size,
args.neg_sample_size, args.neg_sample_size,
args.eval_filter, args.eval_filter,
mode='PBG-tail', mode='PBG-tail',
num_workers=args.num_worker, num_workers=num_workers,
rank=0, ranks=1) rank=0, ranks=1)
# load model # load model
...@@ -145,6 +150,7 @@ def main(args): ...@@ -145,6 +150,7 @@ def main(args):
# test # test
args.step = 0 args.step = 0
args.max_step = 0 args.max_step = 0
start = time.time()
if args.num_proc > 1: if args.num_proc > 1:
queue = mp.Queue(args.num_proc) queue = mp.Queue(args.num_proc)
procs = [] procs = []
...@@ -168,6 +174,7 @@ def main(args): ...@@ -168,6 +174,7 @@ def main(args):
print('Test average {} at [{}/{}]: {}'.format(k, args.step, args.max_step, v)) print('Test average {} at [{}/{}]: {}'.format(k, args.step, args.max_step, v))
else: else:
test(args, model, [test_sampler_head, test_sampler_tail]) test(args, model, [test_sampler_head, test_sampler_tail])
print('Test takes {:.3f} seconds'.format(time.time() - start))
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -55,9 +55,13 @@ class ExternalEmbedding: ...@@ -55,9 +55,13 @@ class ExternalEmbedding:
s = self.emb[idx] s = self.emb[idx]
if self.gpu >= 0: if self.gpu >= 0:
s = s.cuda(self.gpu) s = s.cuda(self.gpu)
data = s.clone().detach().requires_grad_(True) # During the training, we need to trace the computation.
# In this case, we need to record the computation path and compute the gradients.
if trace: if trace:
data = s.clone().detach().requires_grad_(True)
self.trace.append((idx, data)) self.trace.append((idx, data))
else:
data = s
return data return data
def update(self): def update(self):
......
...@@ -171,6 +171,11 @@ def run(args, logger): ...@@ -171,6 +171,11 @@ def run(args, logger):
train_sampler = NewBidirectionalOneShotIterator(train_sampler_head, train_sampler_tail, train_sampler = NewBidirectionalOneShotIterator(train_sampler_head, train_sampler_tail,
True, n_entities) True, n_entities)
# for multiprocessing evaluation, we don't need to sample multiple batches at a time
# in each process.
num_workers = args.num_worker
if args.num_proc > 1:
num_workers = 1
if args.valid or args.test: if args.valid or args.test:
eval_dataset = EvalDataset(dataset, args) eval_dataset = EvalDataset(dataset, args)
if args.valid: if args.valid:
...@@ -184,13 +189,13 @@ def run(args, logger): ...@@ -184,13 +189,13 @@ def run(args, logger):
args.neg_sample_size_valid, args.neg_sample_size_valid,
args.eval_filter, args.eval_filter,
mode='PBG-head', mode='PBG-head',
num_workers=args.num_worker, num_workers=num_workers,
rank=i, ranks=args.num_proc) rank=i, ranks=args.num_proc)
valid_sampler_tail = eval_dataset.create_sampler('valid', args.batch_size_eval, valid_sampler_tail = eval_dataset.create_sampler('valid', args.batch_size_eval,
args.neg_sample_size_valid, args.neg_sample_size_valid,
args.eval_filter, args.eval_filter,
mode='PBG-tail', mode='PBG-tail',
num_workers=args.num_worker, num_workers=num_workers,
rank=i, ranks=args.num_proc) rank=i, ranks=args.num_proc)
valid_sampler_heads.append(valid_sampler_head) valid_sampler_heads.append(valid_sampler_head)
valid_sampler_tails.append(valid_sampler_tail) valid_sampler_tails.append(valid_sampler_tail)
...@@ -199,13 +204,13 @@ def run(args, logger): ...@@ -199,13 +204,13 @@ def run(args, logger):
args.neg_sample_size_valid, args.neg_sample_size_valid,
args.eval_filter, args.eval_filter,
mode='PBG-head', mode='PBG-head',
num_workers=args.num_worker, num_workers=num_workers,
rank=0, ranks=1) rank=0, ranks=1)
valid_sampler_tail = eval_dataset.create_sampler('valid', args.batch_size_eval, valid_sampler_tail = eval_dataset.create_sampler('valid', args.batch_size_eval,
args.neg_sample_size_valid, args.neg_sample_size_valid,
args.eval_filter, args.eval_filter,
mode='PBG-tail', mode='PBG-tail',
num_workers=args.num_worker, num_workers=num_workers,
rank=0, ranks=1) rank=0, ranks=1)
if args.test: if args.test:
# Here we want to use the regular negative sampler because we need to ensure that # Here we want to use the regular negative sampler because we need to ensure that
...@@ -218,13 +223,13 @@ def run(args, logger): ...@@ -218,13 +223,13 @@ def run(args, logger):
args.neg_sample_size_test, args.neg_sample_size_test,
args.eval_filter, args.eval_filter,
mode='PBG-head', mode='PBG-head',
num_workers=args.num_worker, num_workers=num_workers,
rank=i, ranks=args.num_proc) rank=i, ranks=args.num_proc)
test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval, test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval,
args.neg_sample_size_test, args.neg_sample_size_test,
args.eval_filter, args.eval_filter,
mode='PBG-tail', mode='PBG-tail',
num_workers=args.num_worker, num_workers=num_workers,
rank=i, ranks=args.num_proc) rank=i, ranks=args.num_proc)
test_sampler_heads.append(test_sampler_head) test_sampler_heads.append(test_sampler_head)
test_sampler_tails.append(test_sampler_tail) test_sampler_tails.append(test_sampler_tail)
...@@ -233,13 +238,13 @@ def run(args, logger): ...@@ -233,13 +238,13 @@ def run(args, logger):
args.neg_sample_size_test, args.neg_sample_size_test,
args.eval_filter, args.eval_filter,
mode='PBG-head', mode='PBG-head',
num_workers=args.num_worker, num_workers=num_workers,
rank=0, ranks=1) rank=0, ranks=1)
test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval, test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval,
args.neg_sample_size_test, args.neg_sample_size_test,
args.eval_filter, args.eval_filter,
mode='PBG-tail', mode='PBG-tail',
num_workers=args.num_worker, num_workers=num_workers,
rank=0, ranks=1) rank=0, ranks=1)
# We need to free all memory referenced by dataset. # We need to free all memory referenced by dataset.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment