Unverified Commit e6584043 authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[KG] make batch size compatible with neg sample size (#1343)

* adjust batch size.

* fix for eval.py

* adjust parameters in distributed training.

* move code.
parent 7b3a7b14
......@@ -102,13 +102,13 @@ DGLBACKEND=pytorch python3 train.py --model RotatE --dataset wn18 --batch_size 1
DGLBACKEND=pytorch python3 train.py --model ComplEx --dataset Freebase --batch_size 1024 \
--neg_sample_size 256 --hidden_dim 400 --gamma 500.0 --lr 0.1 --max_step 50000 \
--batch_size_eval 128 --test -adv --eval_interval 300000 --num_thread 1 \
--neg_sample_size_test 100000 --eval_percent 0.02 --num_proc 48
--neg_sample_size_eval 100000 --eval_percent 0.02 --num_proc 48
# Freebase multi-gpu
# TransE_l2 8 GPU
DGLBACKEND=pytorch python3 train.py --model TransE_l2 --dataset Freebase --batch_size 1024 \
--neg_sample_size 256 --hidden_dim 400 --gamma 10 --lr 0.1 --batch_size_eval 1000 \
--valid --test -adv --mix_cpu_gpu --neg_deg_sample_eval --neg_sample_size_test 1000 \
--valid --test -adv --mix_cpu_gpu --neg_deg_sample_eval --neg_sample_size_eval 1000 \
--num_proc 8 --gpu 0 1 2 3 4 5 6 7 --num_thread 4 --regularization_coef 1e-9 \
--no_eval_filter --max_step 400000 --rel_part --eval_interval 100000 --log_interval 10000 \
--no_eval_filter --async_update --neg_deg_sample --force_sync_interval 1000
......@@ -116,7 +116,7 @@ DGLBACKEND=pytorch python3 train.py --model TransE_l2 --dataset Freebase --batch
# TransE_l2 16 GPU
DGLBACKEND=pytorch python3 train.py --model TransE_l2 --dataset Freebase --batch_size 1024 \
--neg_sample_size 256 --hidden_dim 400 --gamma 10 --lr 0.1 --batch_size_eval 1000 \
--valid --test -adv --mix_cpu_gpu --neg_deg_sample_eval --neg_sample_size_test 1000 \
--valid --test -adv --mix_cpu_gpu --neg_deg_sample_eval --neg_sample_size_eval 1000 \
--num_proc 16 --gpu 0 1 2 3 4 5 6 7 --num_thread 4 --regularization_coef 1e-9 \
--no_eval_filter --max_step 200000 --soft_rel_part --eval_interval 100000 --log_interval 10000 \
--no_eval_filter --async_update --neg_deg_sample --force_sync_interval 1000
......@@ -434,17 +434,15 @@ def create_neg_subgraph(pos_g, neg_g, chunk_size, neg_sample_size, is_chunked,
# We use all nodes to create negative edges. Regardless of the sampling algorithm,
# we can always view the subgraph with one chunk.
if (neg_head and len(neg_g.head_nid) == num_nodes) \
or (not neg_head and len(neg_g.tail_nid) == num_nodes):
or (not neg_head and len(neg_g.tail_nid) == num_nodes) \
or pos_g.number_of_edges() < chunk_size:
num_chunks = 1
chunk_size = pos_g.number_of_edges()
elif is_chunked:
if pos_g.number_of_edges() < chunk_size:
# This is probably the last batch. Let's ignore it.
if pos_g.number_of_edges() % chunk_size > 0:
return None
else:
# This is probably the last batch. Let's ignore it.
if pos_g.number_of_edges() % chunk_size > 0:
return None
num_chunks = int(pos_g.number_of_edges() / chunk_size)
num_chunks = int(pos_g.number_of_edges() / chunk_size)
assert num_chunks * chunk_size == pos_g.number_of_edges()
else:
num_chunks = pos_g.number_of_edges()
......
......@@ -26,4 +26,4 @@ done
##################################################################################
MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 DGLBACKEND=pytorch python3 ../kvclient.py --model ComplEx --dataset Freebase \
--batch_size 1024 --neg_sample_size 256 --hidden_dim 400 --gamma 143.0 --lr 0.1 --max_step 12500 --log_interval 100 \
--batch_size_eval 1000 --neg_sample_size_test 1000 --test -adv --total_machine 4 --num_thread 1 --num_client 40
\ No newline at end of file
--batch_size_eval 1000 --neg_sample_size_eval 1000 --test -adv --total_machine 4 --num_thread 1 --num_client 40
......@@ -26,4 +26,4 @@ done
##################################################################################
MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 DGLBACKEND=pytorch python3 ../kvclient.py --model DistMult --dataset Freebase \
--batch_size 1024 --neg_sample_size 256 --hidden_dim 400 --gamma 143.0 --lr 0.08 --max_step 12500 --log_interval 100 \
--batch_size_eval 1000 --neg_sample_size_test 1000 --test -adv --total_machine 4 --num_thread 1 --num_client 40
\ No newline at end of file
--batch_size_eval 1000 --neg_sample_size_eval 1000 --test -adv --total_machine 4 --num_thread 1 --num_client 40
......@@ -26,4 +26,4 @@ done
##################################################################################
MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 DGLBACKEND=pytorch python3 ../kvclient.py --model TransE_l2 --dataset Freebase \
--batch_size 1000 --neg_sample_size 200 --hidden_dim 400 --gamma 10 --lr 0.1 --max_step 12500 --log_interval 100 --num_thread 1 \
--batch_size_eval 1000 --neg_sample_size_test 1000 --test -adv --regularization_coef 1e-9 --total_machine 4 --num_client 40
\ No newline at end of file
--batch_size_eval 1000 --neg_sample_size_eval 1000 --test -adv --regularization_coef 1e-9 --total_machine 4 --num_client 40
......@@ -7,6 +7,8 @@ import logging
import time
import pickle
from utils import get_compatible_batch_size
backend = os.environ.get('DGLBACKEND', 'pytorch')
if backend.lower() == 'mxnet':
import multiprocessing as mp
......@@ -38,7 +40,7 @@ class ArgParser(argparse.ArgumentParser):
help='the place where models are saved')
self.add_argument('--batch_size_eval', type=int, default=8,
help='batch size used for eval and test')
self.add_argument('--neg_sample_size_test', type=int, default=-1,
self.add_argument('--neg_sample_size_eval', type=int, default=-1,
help='negative sampling size for testing')
self.add_argument('--neg_deg_sample_eval', action='store_true',
help='negative sampling proportional to vertex degree for testing')
......@@ -85,6 +87,7 @@ def get_logger(args):
print("Logs are being recorded at: {}".format(log_file))
return logger
def main(args):
args.eval_filter = not args.no_eval_filter
if args.neg_deg_sample_eval:
......@@ -105,8 +108,9 @@ def main(args):
# all positive edges are excluded.
eval_dataset = EvalDataset(dataset, args)
if args.neg_sample_size_test < 0:
args.neg_sample_size_test = args.neg_sample_size = eval_dataset.g.number_of_nodes()
if args.neg_sample_size_eval < 0:
args.neg_sample_size_eval = args.neg_sample_size = eval_dataset.g.number_of_nodes()
args.batch_size_eval = get_compatible_batch_size(args.batch_size_eval, args.neg_sample_size_eval)
args.num_workers = 8 # fix num_workers to 8
if args.num_proc > 1:
......@@ -114,15 +118,15 @@ def main(args):
test_sampler_heads = []
for i in range(args.num_proc):
test_sampler_head = eval_dataset.create_sampler('test', args.batch_size_eval,
args.neg_sample_size_test,
args.neg_sample_size_test,
args.neg_sample_size_eval,
args.neg_sample_size_eval,
args.eval_filter,
mode='chunk-head',
num_workers=args.num_workers,
rank=i, ranks=args.num_proc)
test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval,
args.neg_sample_size_test,
args.neg_sample_size_test,
args.neg_sample_size_eval,
args.neg_sample_size_eval,
args.eval_filter,
mode='chunk-tail',
num_workers=args.num_workers,
......@@ -131,15 +135,15 @@ def main(args):
test_sampler_tails.append(test_sampler_tail)
else:
test_sampler_head = eval_dataset.create_sampler('test', args.batch_size_eval,
args.neg_sample_size_test,
args.neg_sample_size_test,
args.neg_sample_size_eval,
args.neg_sample_size_eval,
args.eval_filter,
mode='chunk-head',
num_workers=args.num_workers,
rank=0, ranks=1)
test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval,
args.neg_sample_size_test,
args.neg_sample_size_test,
args.neg_sample_size_eval,
args.neg_sample_size_eval,
args.eval_filter,
mode='chunk-tail',
num_workers=args.num_workers,
......
......@@ -10,6 +10,7 @@ if os.name != 'nt':
import torch.multiprocessing as mp
from train_pytorch import load_model, dist_train_test
from utils import get_compatible_batch_size
from train import get_logger
from dataloader import TrainDataset, NewBidirectionalOneShotIterator
......@@ -51,10 +52,8 @@ class ArgParser(argparse.ArgumentParser):
help='negative sample proportional to vertex degree in the training')
self.add_argument('--neg_deg_sample_eval', action='store_true',
help='negative sampling proportional to vertex degree in the evaluation')
self.add_argument('--neg_sample_size_valid', type=int, default=1000,
help='negative sampling size for validation')
self.add_argument('--neg_sample_size_test', type=int, default=-1,
help='negative sampling size for testing')
self.add_argument('--neg_sample_size_eval', type=int, default=-1,
help='negative sampling size for evaluation')
self.add_argument('--hidden_dim', type=int, default=256,
help='hidden dim used by relation and entity')
self.add_argument('--lr', type=float, default=0.0001,
......@@ -204,6 +203,11 @@ def start_worker(args, logger):
args.strict_rel_part = args.mix_cpu_gpu and (train_data.cross_part == False)
args.soft_rel_part = args.mix_cpu_gpu and args.soft_rel_part and train_data.cross_part
if args.neg_sample_size_eval < 0:
args.neg_sample_size_eval = dataset.n_entities
args.batch_size = get_compatible_batch_size(args.batch_size, args.neg_sample_size)
args.batch_size_eval = get_compatible_batch_size(args.batch_size_eval, args.neg_sample_size_eval)
args.num_workers = 8 # fix num_workers to 8
train_samplers = []
for i in range(args.num_client):
......@@ -259,4 +263,4 @@ def start_worker(args, logger):
if __name__ == '__main__':
args = ArgParser().parse_args()
logger = get_logger(args)
start_worker(args, logger)
\ No newline at end of file
start_worker(args, logger)
......@@ -7,6 +7,8 @@ import logging
import time
import json
from utils import get_compatible_batch_size
backend = os.environ.get('DGLBACKEND', 'pytorch')
if backend.lower() == 'mxnet':
import multiprocessing as mp
......@@ -52,10 +54,8 @@ class ArgParser(argparse.ArgumentParser):
help='negative sample proportional to vertex degree in the training')
self.add_argument('--neg_deg_sample_eval', action='store_true',
help='negative sampling proportional to vertex degree in the evaluation')
self.add_argument('--neg_sample_size_valid', type=int, default=1000,
help='negative sampling size for validation')
self.add_argument('--neg_sample_size_test', type=int, default=-1,
help='negative sampling size for testing')
self.add_argument('--neg_sample_size_eval', type=int, default=-1,
help='negative sampling size for evaluation')
self.add_argument('--eval_percent', type=float, default=1,
help='sample some percentage for evaluation.')
self.add_argument('--hidden_dim', type=int, default=256,
......@@ -139,8 +139,10 @@ def run(args, logger):
# load dataset and samplers
dataset = get_dataset(args.data_path, args.dataset, args.format, args.data_files)
if args.neg_sample_size_test < 0:
args.neg_sample_size_test = dataset.n_entities
if args.neg_sample_size_eval < 0:
args.neg_sample_size_eval = dataset.n_entities
args.batch_size = get_compatible_batch_size(args.batch_size, args.neg_sample_size)
args.batch_size_eval = get_compatible_batch_size(args.batch_size_eval, args.neg_sample_size_eval)
args.eval_filter = not args.no_eval_filter
if args.neg_deg_sample_eval:
......@@ -211,15 +213,15 @@ def run(args, logger):
valid_sampler_tails = []
for i in range(args.num_proc):
valid_sampler_head = eval_dataset.create_sampler('valid', args.batch_size_eval,
args.neg_sample_size_valid,
args.neg_sample_size_valid,
args.neg_sample_size_eval,
args.neg_sample_size_eval,
args.eval_filter,
mode='chunk-head',
num_workers=args.num_workers,
rank=i, ranks=args.num_proc)
valid_sampler_tail = eval_dataset.create_sampler('valid', args.batch_size_eval,
args.neg_sample_size_valid,
args.neg_sample_size_valid,
args.neg_sample_size_eval,
args.neg_sample_size_eval,
args.eval_filter,
mode='chunk-tail',
num_workers=args.num_workers,
......@@ -228,15 +230,15 @@ def run(args, logger):
valid_sampler_tails.append(valid_sampler_tail)
else: # This is used for debug
valid_sampler_head = eval_dataset.create_sampler('valid', args.batch_size_eval,
args.neg_sample_size_valid,
args.neg_sample_size_valid,
args.neg_sample_size_eval,
args.neg_sample_size_eval,
args.eval_filter,
mode='chunk-head',
num_workers=args.num_workers,
rank=0, ranks=1)
valid_sampler_tail = eval_dataset.create_sampler('valid', args.batch_size_eval,
args.neg_sample_size_valid,
args.neg_sample_size_valid,
args.neg_sample_size_eval,
args.neg_sample_size_eval,
args.eval_filter,
mode='chunk-tail',
num_workers=args.num_workers,
......@@ -247,15 +249,15 @@ def run(args, logger):
test_sampler_heads = []
for i in range(args.num_test_proc):
test_sampler_head = eval_dataset.create_sampler('test', args.batch_size_eval,
args.neg_sample_size_test,
args.neg_sample_size_test,
args.neg_sample_size_eval,
args.neg_sample_size_eval,
args.eval_filter,
mode='chunk-head',
num_workers=args.num_workers,
rank=i, ranks=args.num_test_proc)
test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval,
args.neg_sample_size_test,
args.neg_sample_size_test,
args.neg_sample_size_eval,
args.neg_sample_size_eval,
args.eval_filter,
mode='chunk-tail',
num_workers=args.num_workers,
......@@ -264,15 +266,15 @@ def run(args, logger):
test_sampler_tails.append(test_sampler_tail)
else:
test_sampler_head = eval_dataset.create_sampler('test', args.batch_size_eval,
args.neg_sample_size_test,
args.neg_sample_size_test,
args.neg_sample_size_eval,
args.neg_sample_size_eval,
args.eval_filter,
mode='chunk-head',
num_workers=args.num_workers,
rank=0, ranks=1)
test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval,
args.neg_sample_size_test,
args.neg_sample_size_test,
args.neg_sample_size_eval,
args.neg_sample_size_eval,
args.eval_filter,
mode='chunk-tail',
num_workers=args.num_workers,
......@@ -372,7 +374,7 @@ def run(args, logger):
proc.join()
else:
test(args, model, [test_sampler_head, test_sampler_tail])
print('test:', time.time() - start)
print('testing takes {:.3f} seconds'.format(time.time() - start))
if __name__ == '__main__':
args = ArgParser().parse_args()
......
......@@ -164,7 +164,7 @@ def train(args, model, train_sampler, valid_samplers=None, rank=0, rel_parts=Non
if barrier is not None:
barrier.wait()
test(args, model, valid_samplers, rank, mode='Valid')
print('test:', time.time() - valid_start)
print('validation take {:.3f} seconds:'.format(time.time() - valid_start))
if args.soft_rel_part:
model.prepare_cross_rels(cross_rels)
if barrier is not None:
......@@ -243,8 +243,8 @@ def dist_train_test(args, model, train_sampler, entity_pb, relation_pb, l2g, ran
if args.test:
model_test.share_memory()
if args.neg_sample_size_test < 0:
args.neg_sample_size_test = dataset_full.n_entities
if args.neg_sample_size_eval < 0:
args.neg_sample_size_eval = dataset_full.n_entities
args.eval_filter = not args.no_eval_filter
if args.neg_deg_sample_eval:
assert not args.eval_filter, "if negative sampling based on degree, we can't filter positive edges."
......@@ -285,15 +285,15 @@ def dist_train_test(args, model, train_sampler, entity_pb, relation_pb, l2g, ran
test_sampler_heads = []
for i in range(args.num_test_proc):
test_sampler_head = eval_dataset.create_sampler('test', args.batch_size_eval,
args.neg_sample_size_test,
args.neg_sample_size_test,
args.neg_sample_size_eval,
args.neg_sample_size_eval,
args.eval_filter,
mode='chunk-head',
num_workers=args.num_workers,
rank=i, ranks=args.num_test_proc)
test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval,
args.neg_sample_size_test,
args.neg_sample_size_test,
args.neg_sample_size_eval,
args.neg_sample_size_eval,
args.eval_filter,
mode='chunk-tail',
num_workers=args.num_workers,
......@@ -334,4 +334,4 @@ def dist_train_test(args, model, train_sampler, entity_pb, relation_pb, l2g, ran
proc.join()
if client.get_id() == 0:
client.shut_down()
\ No newline at end of file
client.shut_down()
import math
def get_compatible_batch_size(batch_size, neg_sample_size):
    """Round ``batch_size`` up so it is a multiple of ``neg_sample_size``.

    Chunked negative sampling requires the positive batch to split evenly
    into chunks of ``neg_sample_size`` edges; otherwise the trailing partial
    batch is dropped.  If ``neg_sample_size`` is smaller than ``batch_size``
    and does not divide it, the batch size is raised to the next multiple
    and a notice is printed.  Otherwise the batch size is returned unchanged.

    Parameters
    ----------
    batch_size : int
        Requested (positive-edge) batch size.
    neg_sample_size : int
        Number of negative samples per chunk; assumed positive.

    Returns
    -------
    int
        A batch size compatible with ``neg_sample_size``.
    """
    if neg_sample_size < batch_size and batch_size % neg_sample_size != 0:
        old_batch_size = batch_size
        # Exact integer ceiling division: avoids float rounding errors that
        # math.ceil(a / b) can hit for very large integer operands.
        batch_size = -(-batch_size // neg_sample_size) * neg_sample_size
        print('batch size ({}) is incompatible to the negative sample size ({}). Change the batch size to {}'.format(
            old_batch_size, neg_sample_size, batch_size))
    return batch_size
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment