Unverified Commit 3bc31098 authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

use PBG sampler for testing. (#899)

parent 251a9842
...@@ -105,12 +105,12 @@ def main(args): ...@@ -105,12 +105,12 @@ def main(args):
for i in range(args.num_proc): for i in range(args.num_proc):
test_sampler_head = eval_dataset.create_sampler('test', args.batch_size, test_sampler_head = eval_dataset.create_sampler('test', args.batch_size,
args.neg_sample_size, args.neg_sample_size,
mode='head', mode='PBG-head',
num_workers=args.num_worker, num_workers=args.num_worker,
rank=i, ranks=args.num_proc) rank=i, ranks=args.num_proc)
test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size, test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size,
args.neg_sample_size, args.neg_sample_size,
mode='tail', mode='PBG-tail',
num_workers=args.num_worker, num_workers=args.num_worker,
rank=i, ranks=args.num_proc) rank=i, ranks=args.num_proc)
test_sampler_heads.append(test_sampler_head) test_sampler_heads.append(test_sampler_head)
...@@ -118,12 +118,12 @@ def main(args): ...@@ -118,12 +118,12 @@ def main(args):
else: else:
test_sampler_head = eval_dataset.create_sampler('test', args.batch_size, test_sampler_head = eval_dataset.create_sampler('test', args.batch_size,
args.neg_sample_size, args.neg_sample_size,
mode='head', mode='PBG-head',
num_workers=args.num_worker, num_workers=args.num_worker,
rank=0, ranks=1) rank=0, ranks=1)
test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size, test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size,
args.neg_sample_size, args.neg_sample_size,
mode='tail', mode='PBG-tail',
num_workers=args.num_worker, num_workers=args.num_worker,
rank=0, ranks=1) rank=0, ranks=1)
......
...@@ -208,12 +208,12 @@ def run(args, logger): ...@@ -208,12 +208,12 @@ def run(args, logger):
for i in range(args.num_proc): for i in range(args.num_proc):
test_sampler_head = eval_dataset.create_sampler('test', args.batch_size_eval, test_sampler_head = eval_dataset.create_sampler('test', args.batch_size_eval,
args.neg_sample_size_test, args.neg_sample_size_test,
mode='head', mode='PBG-head',
num_workers=args.num_worker, num_workers=args.num_worker,
rank=i, ranks=args.num_proc) rank=i, ranks=args.num_proc)
test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval, test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval,
args.neg_sample_size_test, args.neg_sample_size_test,
mode='tail', mode='PBG-tail',
num_workers=args.num_worker, num_workers=args.num_worker,
rank=i, ranks=args.num_proc) rank=i, ranks=args.num_proc)
test_sampler_heads.append(test_sampler_head) test_sampler_heads.append(test_sampler_head)
...@@ -221,12 +221,12 @@ def run(args, logger): ...@@ -221,12 +221,12 @@ def run(args, logger):
else: else:
test_sampler_head = eval_dataset.create_sampler('test', args.batch_size_eval, test_sampler_head = eval_dataset.create_sampler('test', args.batch_size_eval,
args.neg_sample_size_test, args.neg_sample_size_test,
mode='head', mode='PBG-head',
num_workers=args.num_worker, num_workers=args.num_worker,
rank=0, ranks=1) rank=0, ranks=1)
test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval, test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval,
args.neg_sample_size_test, args.neg_sample_size_test,
mode='tail', mode='PBG-tail',
num_workers=args.num_worker, num_workers=args.num_worker,
rank=0, ranks=1) rank=0, ranks=1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment