Unverified Commit bb6a6476 authored by xiang song(charlie.song), committed by GitHub

[Feature][KG] Multi-GPU training support for DGL KGE (#1178)

* multi-gpu

* Pytorch can run but test has acc problem

* pytorch train/eval can run in multi-gpu

* Fix eval

* Fix

* Fix mxnet

* trigger

* trigger

* Fix mxnet score_func

* Fix

* check

* Fix default arg

* Fix train_mxnet mix_cpu_gpu

* Make relation mix_cpu_gpu

* delete some dead code

* some opt for update

* Fix cpu grad update
parent f1420d19
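
In outline, the patch gives every training/evaluation worker a rank and lets each worker derive its device from the shared --gpu list, round-robin when workers outnumber GPUs. A condensed, runnable sketch of that pattern with toy stand-ins for args.gpu and args.num_proc (simplified: the real code below round-robins only when --mix_cpu_gpu is set and num_proc > 1):

import torch.multiprocessing as mp

GPUS = [0, 1]      # stand-in for args.gpu
NUM_PROC = 4       # stand-in for args.num_proc

def worker(rank):
    # Same rank -> device rule as the patched train()/test().
    gpu_id = GPUS[rank % len(GPUS)] if NUM_PROC > 1 else GPUS[0]
    print('rank {} uses gpu {}'.format(rank, gpu_id))

if __name__ == '__main__':
    procs = [mp.Process(target=worker, args=(i,)) for i in range(NUM_PROC)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
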
@@ -21,26 +21,31 @@ def RelationPartition(edges, n):
edge_cnts = np.zeros(shape=(n,), dtype=np.int64)
rel_cnts = np.zeros(shape=(n,), dtype=np.int64)
rel_dict = {}
rel_parts = []
for _ in range(n):
rel_parts.append([])
for i in range(len(cnts)):
cnt = cnts[i]
r = uniq[i]
idx = np.argmin(edge_cnts)
rel_dict[r] = idx
rel_parts[idx].append(r)
edge_cnts[idx] += cnt
rel_cnts[idx] += 1
for i, edge_cnt in enumerate(edge_cnts):
print('part {} has {} edges and {} relations'.format(i, edge_cnt, rel_cnts[i]))
parts = []
for _ in range(n):
for i in range(n):
parts.append([])
rel_parts[i] = np.array(rel_parts[i])
# let's store the edge index to each partition first.
for i, r in enumerate(rels):
part_idx = rel_dict[r]
parts[part_idx].append(i)
for i, part in enumerate(parts):
parts[i] = np.array(part, dtype=np.int64)
return parts
return parts, rel_parts
def RandomPartition(edges, n):
heads, rels, tails = edges
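
RelationPartition greedily assigns each relation to the partition with the fewest edges so far, so all edges of one relation stay on one worker; it now also returns rel_parts, the relation ids owned by each partition. A toy call, assuming RelationPartition from the hunk above is in scope (the triples are made up):

import numpy as np

heads = np.array([0, 1, 2, 3, 4, 5], dtype=np.int64)
rels  = np.array([0, 0, 0, 1, 1, 2], dtype=np.int64)   # relation 0 is the heaviest
tails = np.array([1, 2, 3, 4, 5, 0], dtype=np.int64)

parts, rel_parts = RelationPartition((heads, rels, tails), 2)
# parts[i]     -> indices of the edges assigned to partition i
# rel_parts[i] -> relation ids owned exclusively by partition i
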
@@ -79,7 +84,7 @@ class TrainDataset(object):
num_train = len(triples[0])
print('|Train|:', num_train)
if ranks > 1 and args.rel_part:
self.edge_parts = RelationPartition(triples, ranks)
self.edge_parts, self.rel_parts = RelationPartition(triples, ranks)
elif ranks > 1:
self.edge_parts = RandomPartition(triples, ranks)
else:
@@ -51,8 +51,8 @@ class ArgParser(argparse.ArgumentParser):
self.add_argument('--no_eval_filter', action='store_true',
help='do not filter positive edges among negative edges for evaluation')
self.add_argument('--gpu', type=int, default=-1,
help='use GPU')
self.add_argument('--gpu', type=int, default=[-1], nargs='+',
help='a list of active gpu ids, e.g. 0')
self.add_argument('--mix_cpu_gpu', action='store_true',
help='mix CPU and GPU training')
self.add_argument('-de', '--double_ent', action='store_true',
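
The --gpu option now takes a list of device ids via nargs='+'. A quick standalone check of the parsing behavior (a throwaway parser mirroring the two lines added above):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=int, default=[-1], nargs='+',
                    help='a list of active gpu ids, e.g. 0 1 2 4')

print(parser.parse_args([]).gpu)                        # [-1] -> CPU
print(parser.parse_args(['--gpu', '0']).gpu)            # [0]  -> single GPU
print(parser.parse_args('--gpu 0 1 2 3'.split()).gpu)   # [0, 1, 2, 3]
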
@@ -100,6 +100,7 @@ def main(args):
args.train = False
args.valid = False
args.test = True
args.rel_part = False
args.batch_size_eval = args.batch_size
logger = get_logger(args)
@@ -172,7 +173,7 @@ def main(args):
procs = []
for i in range(args.num_proc):
proc = mp.Process(target=test, args=(args, model, [test_sampler_heads[i], test_sampler_tails[i]],
'Test', queue))
i, 'Test', queue))
procs.append(proc)
proc.start()
for proc in procs:
@@ -28,6 +28,7 @@ class KEModel(object):
super(KEModel, self).__init__()
self.args = args
self.n_entities = n_entities
self.n_relations = n_relations
self.model_name = model_name
self.hidden_dim = hidden_dim
self.eps = 2.0
@@ -44,7 +45,9 @@ class KEModel(object):
rel_dim = relation_dim * entity_dim
else:
rel_dim = relation_dim
self.relation_emb = ExternalEmbedding(args, n_relations, rel_dim, device)
self.rel_dim = rel_dim
self.relation_emb = ExternalEmbedding(args, n_relations, rel_dim, F.cpu() if args.mix_cpu_gpu else device)
if model_name == 'TransE' or model_name == 'TransE_l2':
self.score_func = TransEScore(gamma, 'l2')
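
With --mix_cpu_gpu, the relation embedding is allocated in CPU memory and individual rows are shipped to a worker's GPU only when used, so all processes can share one table. A minimal sketch of that pull pattern (toy sizes, plain torch tensors rather than ExternalEmbedding; gpu_id = -1 keeps everything on CPU):

import torch as th

n_relations, rel_dim, gpu_id = 10, 4, -1
rel_table = th.zeros(n_relations, rel_dim)  # CPU-resident, shared across workers
rel_ids = th.tensor([0, 3, 3, 7])           # relations touched by one minibatch

rows = rel_table[rel_ids]                   # gather on CPU
if gpu_id >= 0:
    rows = rows.cuda(gpu_id)                # move just these rows to this worker's GPU
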
@@ -87,8 +90,8 @@ class KEModel(object):
def reset_parameters(self):
self.entity_emb.init(self.emb_init)
self.relation_emb.init(self.emb_init)
self.score_func.reset_parameters()
self.relation_emb.init(self.emb_init)
def predict_score(self, g):
self.score_func(g)
@@ -174,8 +177,8 @@ class KEModel(object):
# We need to filter the positive edges in the negative graph.
if self.args.eval_filter:
filter_bias = reshape(neg_g.edata['bias'], batch_size, -1)
if self.args.gpu >= 0:
filter_bias = cuda(filter_bias, self.args.gpu)
if gpu_id >= 0:
filter_bias = cuda(filter_bias, gpu_id)
neg_scores += filter_bias
# To compute the rank of a positive edge among all negative edges,
# we need to know how many negative edges have higher scores than
@@ -244,7 +247,8 @@ class KEModel(object):
return loss, log
def update(self):
self.entity_emb.update()
self.relation_emb.update()
self.score_func.update()
def update(self, gpu_id=-1):
self.entity_emb.update(gpu_id)
self.relation_emb.update(gpu_id)
self.score_func.update(gpu_id)
@@ -46,7 +46,7 @@ class TransEScore(nn.Block):
return head, tail
return fn
def update(self):
def update(self, gpu_id=-1):
pass
def reset_parameters(self):
@@ -171,8 +171,8 @@ class TransRScore(nn.Block):
def reset_parameters(self):
self.projection_emb.init(1.0)
def update(self):
self.projection_emb.update()
def update(self, gpu_id=-1):
self.projection_emb.update(gpu_id)
def save(self, path, name):
self.projection_emb.save(path, name+'projection')
@@ -219,7 +219,7 @@ class DistMultScore(nn.Block):
return head, tail
return fn
def update(self):
def update(self, gpu_id=-1):
pass
def reset_parameters(self):
@@ -276,7 +276,7 @@ class ComplExScore(nn.Block):
return head, tail
return fn
def update(self):
def update(self, gpu_id=-1):
pass
def reset_parameters(self):
@@ -344,7 +344,7 @@ class RESCALScore(nn.Block):
return head, tail
return fn
def update(self):
def update(self, gpu_id=-1):
pass
def reset_parameters(self):
@@ -412,7 +412,7 @@ class RotatEScore(nn.Block):
return head, tail
return fn
def update(self):
def update(self, gpu_id=-1):
pass
def reset_parameters(self):
@@ -12,7 +12,7 @@ def logsigmoid(val):
z = nd.exp(-max_elem) + nd.exp(-val - max_elem)
return -(max_elem + nd.log(z))
get_device = lambda args : mx.gpu(args.gpu) if args.gpu >= 0 else mx.cpu()
get_device = lambda args : mx.gpu(args.gpu[0]) if args.gpu[0] >= 0 else mx.cpu()
norm = lambda x, p: nd.sum(nd.abs(x) ** p)
get_scalar = lambda x: x.detach().asscalar()
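
For reference, the logsigmoid above is the max-subtraction form of log sigmoid(x) = -log(1 + e^(-x)), which avoids overflow for large negative inputs. A NumPy stand-in to check it (max_elem here plays the same role as in the code above):

import numpy as np

def logsigmoid_stable(x):
    max_elem = np.maximum(-x, 0.0)
    z = np.exp(-max_elem) + np.exp(-x - max_elem)
    return -(max_elem + np.log(z))

x = np.array([-5.0, 0.0, 5.0])
print(np.allclose(logsigmoid_stable(x), np.log(1.0 / (1.0 + np.exp(-x)))))  # True
print(logsigmoid_stable(np.array([-1000.0])))  # [-1000.] -- the naive form overflows here
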
@@ -44,14 +44,14 @@ class ExternalEmbedding:
if self.emb.context != idx.context:
idx = idx.as_in_context(self.emb.context)
data = nd.take(self.emb, idx)
if self.gpu >= 0:
data = data.as_in_context(mx.gpu(self.gpu))
if gpu_id >= 0:
data = data.as_in_context(mx.gpu(gpu_id))
data.attach_grad()
if trace:
self.trace.append((idx, data))
return data
def update(self):
def update(self, gpu_id=-1):
self.state_step += 1
for idx, data in self.trace:
grad = data.grad
@@ -71,9 +71,9 @@ class ExternalEmbedding:
grad_sum = grad_sum.as_in_context(ctx)
self.state_sum[grad_indices] += grad_sum
std = self.state_sum[grad_indices] # _sparse_mask
if gpu_id >= 0:
std = std.as_in_context(mx.gpu(gpu_id))
std_values = nd.expand_dims(nd.sqrt(std) + 1e-10, 1)
if self.gpu >= 0:
std_values = std_values.as_in_context(mx.gpu(self.args.gpu))
tmp = (-clr * grad_values / std_values)
if tmp.context != ctx:
tmp = tmp.as_in_context(ctx)
@@ -47,7 +47,7 @@ class TransEScore(nn.Module):
def forward(self, g):
g.apply_edges(lambda edges: self.edge_func(edges))
def update(self):
def update(self, gpu_id=-1):
pass
def reset_parameters(self):
@@ -138,8 +138,8 @@ class TransRScore(nn.Module):
def reset_parameters(self):
self.projection_emb.init(1.0)
def update(self):
self.projection_emb.update()
def update(self, gpu_id=-1):
self.projection_emb.update(gpu_id)
def save(self, path, name):
self.projection_emb.save(path, name+'projection')
@@ -186,7 +186,7 @@ class DistMultScore(nn.Module):
return head, tail
return fn
def update(self):
def update(self, gpu_id=-1):
pass
def reset_parameters(self):
@@ -243,7 +243,7 @@ class ComplExScore(nn.Module):
return head, tail
return fn
def update(self):
def update(self, gpu_id=-1):
pass
def reset_parameters(self):
@@ -314,7 +314,7 @@ class RESCALScore(nn.Module):
return head, tail
return fn
def update(self):
def update(self, gpu_id=-1):
pass
def reset_parameters(self):
@@ -373,7 +373,7 @@ class RotatEScore(nn.Module):
score = score.norm(dim=0)
return {'score': self.gamma - score.sum(-1)}
def update(self):
def update(self, gpu_id=-1):
pass
def reset_parameters(self):
@@ -23,7 +23,7 @@ from .. import *
logsigmoid = functional.logsigmoid
def get_device(args):
return th.device('cpu') if args.gpu < 0 else th.device('cuda:' + str(args.gpu))
return th.device('cpu') if args.gpu[0] < 0 else th.device('cuda:' + str(args.gpu[0]))
norm = lambda x, p: x.norm(p=p)**p
@@ -53,8 +53,8 @@ class ExternalEmbedding:
def __call__(self, idx, gpu_id=-1, trace=True):
s = self.emb[idx]
if self.gpu >= 0:
s = s.cuda(self.gpu)
if gpu_id >= 0:
s = s.cuda(gpu_id)
# During the training, we need to trace the computation.
# In this case, we need to record the computation path and compute the gradients.
if trace:
@@ -64,7 +64,7 @@ class ExternalEmbedding:
data = s
return data
def update(self):
def update(self, gpu_id=-1):
self.state_step += 1
with th.no_grad():
for idx, data in self.trace:
@@ -85,9 +85,9 @@ class ExternalEmbedding:
grad_sum = grad_sum.to(device)
self.state_sum.index_add_(0, grad_indices, grad_sum)
std = self.state_sum[grad_indices] # _sparse_mask
if gpu_id >= 0:
std = std.cuda(gpu_id)
std_values = std.sqrt_().add_(1e-10).unsqueeze(1)
if self.gpu >= 0:
std_values = std_values.cuda(self.args.gpu)
tmp = (-clr * grad_values / std_values)
if tmp.device != device:
tmp = tmp.to(device)
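
The optimizer here is row-sparse Adagrad with one scalar state per embedding row: only rows touched in the batch accumulate squared-gradient mass and get updated, and the scaled step can be computed on the worker's GPU before being written back to the CPU table. A condensed sketch of one such step (the per-row mean reduction for the state is an assumption; it is computed just above this hunk):

import torch as th

def sparse_adagrad_step(emb, state_sum, idx, grad, lr):
    # emb: (N, d) CPU table; state_sum: (N,) scalar Adagrad state per row.
    grad_sum = (grad * grad).mean(1)              # assumed row reduction
    state_sum.index_add_(0, idx, grad_sum)
    std = state_sum[idx].sqrt().add_(1e-10).unsqueeze(1)
    emb.index_add_(0, idx, -lr * grad / std)

emb = th.zeros(5, 3)
state = th.zeros(5)
idx = th.tensor([1, 3])
sparse_adagrad_step(emb, state, idx, th.ones(2, 3), lr=0.1)
print(emb[idx])   # only rows 1 and 3 moved
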
@@ -72,8 +72,8 @@ class ArgParser(argparse.ArgumentParser):
self.add_argument('--no_eval_filter', action='store_true',
help='do not filter positive edges among negative edges for evaluation')
self.add_argument('--gpu', type=int, default=-1,
help='use GPU')
self.add_argument('--gpu', type=int, default=[-1], nargs='+',
help='a list of active gpu ids, e.g. 0 1 2 4')
self.add_argument('--mix_cpu_gpu', action='store_true',
help='mix CPU and GPU training')
self.add_argument('-de', '--double_ent', action='store_true',
@@ -298,8 +298,9 @@ def run(args, logger):
if args.num_proc > 1:
procs = []
for i in range(args.num_proc):
rel_parts = train_data.rel_parts if args.rel_part else None
valid_samplers = [valid_sampler_heads[i], valid_sampler_tails[i]] if args.valid else None
proc = mp.Process(target=train, args=(args, model, train_samplers[i], valid_samplers))
proc = mp.Process(target=train, args=(args, model, train_samplers[i], i, rel_parts, valid_samplers))
procs.append(proc)
proc.start()
for proc in procs:
@@ -322,7 +323,7 @@ def run(args, logger):
procs = []
for i in range(args.num_proc):
proc = mp.Process(target=test, args=(args, model, [test_sampler_heads[i], test_sampler_tails[i]],
'Test', queue))
i, 'Test', queue))
procs.append(proc)
proc.start()
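
Each evaluation worker now receives its rank, and the parent collects metrics from a shared queue. A minimal sketch of the collect side (hypothetical; the real aggregation code is outside this hunk, and eval_worker stands in for test()):

import torch.multiprocessing as mp

def eval_worker(rank, queue):
    # Stand-in for test(): push this rank's metrics back to the parent.
    queue.put({'rank': rank, 'MRR': 0.5})

if __name__ == '__main__':
    queue = mp.Queue()
    procs = [mp.Process(target=eval_worker, args=(i, queue)) for i in range(2)]
    for p in procs:
        p.start()
    results = [queue.get() for _ in procs]
    for p in procs:
        p.join()
    print(results)
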
@@ -14,11 +14,7 @@ def load_model(logger, args, n_entities, n_relations, ckpt=None):
args.hidden_dim, args.gamma,
double_entity_emb=args.double_ent, double_relation_emb=args.double_rel)
if ckpt is not None:
# TODO: loading model emb only works for general Embedding, not for ExternalEmbedding
if args.gpu >= 0:
model.load_parameters(ckpt, ctx=mx.gpu(args.gpu))
else:
model.load_parameters(ckpt, ctx=mx.cpu())
assert False, "We do not support loading model emb for general Embedding"
logger.info('Load model {}'.format(args.model_name))
return model
@@ -28,23 +24,28 @@ def load_model_from_checkpoint(logger, args, n_entities, n_relations, ckpt_path)
model.load_emb(ckpt_path, args.dataset)
return model
def train(args, model, train_sampler, valid_samplers=None):
if args.num_proc > 1:
os.environ['OMP_NUM_THREADS'] = '1'
def train(args, model, train_sampler, rank=0, rel_parts=None, valid_samplers=None):
assert args.num_proc == 1, "MXNet KGE does not support multi-process now"
assert args.rel_part == False, "No need for relation partition in single process for MXNet KGE"
logs = []
for arg in vars(args):
logging.info('{:20}:{}'.format(arg, getattr(args, arg)))
if len(args.gpu) > 0:
gpu_id = args.gpu[rank % len(args.gpu)] if args.mix_cpu_gpu and args.num_proc > 1 else args.gpu[0]
else:
gpu_id = -1
start = time.time()
for step in range(args.init_step, args.max_step):
pos_g, neg_g = next(train_sampler)
args.step = step
with mx.autograd.record():
loss, log = model.forward(pos_g, neg_g, args.gpu)
loss, log = model.forward(pos_g, neg_g, gpu_id)
loss.backward()
logs.append(log)
model.update()
model.update(gpu_id)
if step % args.log_interval == 0:
for k in logs[0].keys():
@@ -61,14 +62,20 @@ def train(args, model, train_sampler, valid_samplers=None):
# clear cache
logs = []
def test(args, model, test_samplers, mode='Test', queue=None):
def test(args, model, test_samplers, rank=0, mode='Test', queue=None):
assert args.num_proc == 1, "MXNet KGE does not support multi-process now"
logs = []
if len(args.gpu) > 0:
gpu_id = args.gpu[rank % len(args.gpu)] if args.mix_cpu_gpu and args.num_proc > 1 else args.gpu[0]
else:
gpu_id = -1
for sampler in test_samplers:
#print('Number of tests: ' + str(len(sampler)))
count = 0
for pos_g, neg_g in sampler:
model.forward_test(pos_g, neg_g, logs, args.gpu)
model.forward_test(pos_g, neg_g, logs, gpu_id)
metrics = {}
if len(logs) > 0:
@@ -4,6 +4,8 @@ from torch.utils.data import DataLoader
import torch.optim as optim
import torch as th
import torch.multiprocessing as mp
from torch.multiprocessing import Queue
from _thread import start_new_thread
from distutils.version import LooseVersion
TH_VERSION = LooseVersion(th.__version__)
@@ -13,14 +15,36 @@ if TH_VERSION.version[0] == 1 and TH_VERSION.version[1] < 2:
import os
import logging
import time
import traceback
from functools import wraps
def thread_wrapped_func(func):
@wraps(func)
def decorated_function(*args, **kwargs):
queue = Queue()
def _queue_result():
exception, trace, res = None, None, None
try:
res = func(*args, **kwargs)
except Exception as e:
exception = e
trace = traceback.format_exc()
queue.put((res, exception, trace))
start_new_thread(_queue_result, ())
result, exception, trace = queue.get()
if exception is None:
return result
else:
assert isinstance(exception, Exception)
raise exception.__class__(trace)
return decorated_function
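
thread_wrapped_func runs the wrapped function in a fresh native thread and relays its result (or exception plus traceback) back through a torch.multiprocessing Queue; this sidesteps OpenMP-related hangs when PyTorch code runs inside forked worker processes. A usage sketch, assuming the decorator above is in scope:

@thread_wrapped_func
def work(x):
    # Runs in its own thread; exceptions propagate with their original traceback.
    return x * 2

print(work(21))  # 42
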
def load_model(logger, args, n_entities, n_relations, ckpt=None):
model = KEModel(args, args.model_name, n_entities, n_relations,
args.hidden_dim, args.gamma,
double_entity_emb=args.double_ent, double_relation_emb=args.double_rel)
if ckpt is not None:
# TODO: loading model emb only works for general Embedding, not for ExternalEmbedding
model.load_state_dict(ckpt['model_state_dict'])
assert False, "We do not support loading model emb for general Embedding"
return model
@@ -29,13 +53,19 @@ def load_model_from_checkpoint(logger, args, n_entities, n_relations, ckpt_path)
model.load_emb(ckpt_path, args.dataset)
return model
def train(args, model, train_sampler, valid_samplers=None):
@thread_wrapped_func
def train(args, model, train_sampler, rank=0, rel_parts=None, valid_samplers=None):
if args.num_proc > 1:
th.set_num_threads(1)
th.set_num_threads(4)
logs = []
for arg in vars(args):
logging.info('{:20}:{}'.format(arg, getattr(args, arg)))
if len(args.gpu) > 0:
gpu_id = args.gpu[rank % len(args.gpu)] if args.mix_cpu_gpu and args.num_proc > 1 else args.gpu[0]
else:
gpu_id = -1
start = time.time()
sample_time = 0
update_time = 0
@@ -48,7 +78,7 @@ def train(args, model, train_sampler, valid_samplers=None):
args.step = step
start1 = time.time()
loss, log = model.forward(pos_g, neg_g)
loss, log = model.forward(pos_g, neg_g, gpu_id)
forward_time += time.time() - start1
start1 = time.time()
@@ -56,7 +86,7 @@
backward_time += time.time() - start1
start1 = time.time()
model.update()
model.update(gpu_id)
update_time += time.time() - start1
logs.append(log)
@@ -80,16 +110,23 @@ def train(args, model, train_sampler, valid_samplers=None):
test(args, model, valid_samplers, mode='Valid')
print('test:', time.time() - start)
def test(args, model, test_samplers, mode='Test', queue=None):
@thread_wrapped_func
def test(args, model, test_samplers, rank=0, mode='Test', queue=None):
if args.num_proc > 1:
th.set_num_threads(1)
th.set_num_threads(4)
if len(args.gpu) > 0:
gpu_id = args.gpu[rank % len(args.gpu)] if args.mix_cpu_gpu and args.num_proc > 1 else args.gpu[0]
else:
gpu_id = -1
with th.no_grad():
logs = []
for sampler in test_samplers:
count = 0
for pos_g, neg_g in sampler:
with th.no_grad():
model.forward_test(pos_g, neg_g, logs, args.gpu)
model.forward_test(pos_g, neg_g, logs, gpu_id)
metrics = {}
if len(logs) > 0: