Unverified commit bb6a6476, authored by xiang song(charlie.song), committed by GitHub

[Feature][KG] Multi-GPU training support for DGL KGE (#1178)

* Multi-GPU

* PyTorch can run but the test has an accuracy problem

* PyTorch train/eval can run multi-GPU

* Fix eval

* Fix

* Fix MXNet

* Trigger

* Trigger

* Fix MXNet score_func

* Fix

* Check

* Fix default arg

* Fix train_mxnet mix_cpu_gpu

* Make relation embedding mix_cpu_gpu

* Delete some dead code

* Some optimizations for update

* Fix CPU grad update
parent f1420d19
@@ -21,26 +21,31 @@ def RelationPartition(edges, n):
     edge_cnts = np.zeros(shape=(n,), dtype=np.int64)
     rel_cnts = np.zeros(shape=(n,), dtype=np.int64)
     rel_dict = {}
+    rel_parts = []
+    for _ in range(n):
+        rel_parts.append([])
     for i in range(len(cnts)):
         cnt = cnts[i]
         r = uniq[i]
         idx = np.argmin(edge_cnts)
         rel_dict[r] = idx
+        rel_parts[idx].append(r)
         edge_cnts[idx] += cnt
         rel_cnts[idx] += 1
     for i, edge_cnt in enumerate(edge_cnts):
         print('part {} has {} edges and {} relations'.format(i, edge_cnt, rel_cnts[i]))
     parts = []
-    for _ in range(n):
+    for i in range(n):
         parts.append([])
+        rel_parts[i] = np.array(rel_parts[i])
     # let's store the edge index to each partition first.
     for i, r in enumerate(rels):
         part_idx = rel_dict[r]
         parts[part_idx].append(i)
     for i, part in enumerate(parts):
         parts[i] = np.array(part, dtype=np.int64)
-    return parts
+    return parts, rel_parts

 def RandomPartition(edges, n):
     heads, rels, tails = edges
@@ -79,7 +84,7 @@ class TrainDataset(object):
         num_train = len(triples[0])
         print('|Train|:', num_train)
         if ranks > 1 and args.rel_part:
-            self.edge_parts = RelationPartition(triples, ranks)
+            self.edge_parts, self.rel_parts = RelationPartition(triples, ranks)
         elif ranks > 1:
             self.edge_parts = RandomPartition(triples, ranks)
         else:
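Why RelationPartition assigns each relation wholly to one partition: relation embeddings are later kept per worker, so every edge carrying a relation must land in that relation's partition, and the hunk above now also returns which relations each partition owns. A minimal, self-contained sketch of the greedy balancing on toy data (the function name and data are illustrative only):

import numpy as np

def relation_partition_sketch(rels, n):
    # Greedy bin-packing: hand each relation (with all of its edges) to the
    # partition that currently holds the fewest edges.
    uniq, cnts = np.unique(rels, return_counts=True)
    edge_cnts = np.zeros(n, dtype=np.int64)
    rel_dict = {}
    rel_parts = [[] for _ in range(n)]
    for r, cnt in zip(uniq, cnts):
        idx = np.argmin(edge_cnts)
        rel_dict[r] = idx
        rel_parts[idx].append(r)
        edge_cnts[idx] += cnt
    parts = [[] for _ in range(n)]
    for i, r in enumerate(rels):
        parts[rel_dict[r]].append(i)
    return ([np.array(p, dtype=np.int64) for p in parts],
            [np.array(p) for p in rel_parts])

rels = np.array([0, 0, 0, 0, 1, 2])   # relation 0 carries 4x more edges
parts, rel_parts = relation_partition_sketch(rels, 2)
print(rel_parts)                      # [array([0]), array([1, 2])]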
@@ -51,8 +51,8 @@ class ArgParser(argparse.ArgumentParser):
         self.add_argument('--no_eval_filter', action='store_true',
                           help='do not filter positive edges among negative edges for evaluation')
-        self.add_argument('--gpu', type=int, default=-1,
-                          help='use GPU')
+        self.add_argument('--gpu', type=int, default=[-1], nargs='+',
+                          help='a list of active gpu ids, e.g. 0')
         self.add_argument('--mix_cpu_gpu', action='store_true',
                           help='mix CPU and GPU training')
         self.add_argument('-de', '--double_ent', action='store_true',
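`--gpu` switches from a single int to `nargs='+'`, so a space-separated list of device ids parses into a Python list while the default stays `[-1]` (CPU). A quick standalone check of that argparse behavior:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=int, default=[-1], nargs='+',
                    help='a list of active gpu ids, e.g. 0 1 2 4')

print(parser.parse_args([]).gpu)                   # [-1] -> run on CPU
print(parser.parse_args('--gpu 0'.split()).gpu)    # [0]
print(parser.parse_args('--gpu 0 1'.split()).gpu)  # [0, 1]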
@@ -100,6 +100,7 @@ def main(args):
     args.train = False
     args.valid = False
     args.test = True
+    args.rel_part = False
     args.batch_size_eval = args.batch_size

     logger = get_logger(args)
@@ -172,7 +173,7 @@ def main(args):
         procs = []
         for i in range(args.num_proc):
             proc = mp.Process(target=test, args=(args, model, [test_sampler_heads[i], test_sampler_tails[i]],
-                                                 'Test', queue))
+                                                 i, 'Test', queue))
             procs.append(proc)
             proc.start()
         for proc in procs:
@@ -28,6 +28,7 @@ class KEModel(object):
         super(KEModel, self).__init__()
         self.args = args
         self.n_entities = n_entities
+        self.n_relations = n_relations
         self.model_name = model_name
         self.hidden_dim = hidden_dim
         self.eps = 2.0
@@ -44,7 +45,9 @@ class KEModel(object):
             rel_dim = relation_dim * entity_dim
         else:
             rel_dim = relation_dim
-        self.relation_emb = ExternalEmbedding(args, n_relations, rel_dim, device)
+
+        self.rel_dim = rel_dim
+        self.relation_emb = ExternalEmbedding(args, n_relations, rel_dim, F.cpu() if args.mix_cpu_gpu else device)

         if model_name == 'TransE' or model_name == 'TransE_l2':
             self.score_func = TransEScore(gamma, 'l2')
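With `--mix_cpu_gpu`, the relation embedding is now allocated on CPU rather than on the training device, which lets several worker processes share one copy and pull only the rows a batch needs onto their own GPU. A rough sketch of that access pattern (PyTorch, illustrative names; the real logic lives in ExternalEmbedding):

import torch as th

rel_emb = th.randn(1000, 400)   # relation table kept on CPU
rel_emb.share_memory_()         # visible to all forked workers

def fetch_rows(idx, gpu_id=-1):
    rows = rel_emb[idx]
    # Ship only the selected rows to this worker's device.
    return rows.cuda(gpu_id) if gpu_id >= 0 else rows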
@@ -87,8 +90,8 @@ class KEModel(object):
     def reset_parameters(self):
         self.entity_emb.init(self.emb_init)
-        self.relation_emb.init(self.emb_init)
         self.score_func.reset_parameters()
+        self.relation_emb.init(self.emb_init)

     def predict_score(self, g):
         self.score_func(g)
@@ -174,8 +177,8 @@ class KEModel(object):
         # We need to filter the positive edges in the negative graph.
         if self.args.eval_filter:
             filter_bias = reshape(neg_g.edata['bias'], batch_size, -1)
-            if self.args.gpu >= 0:
-                filter_bias = cuda(filter_bias, self.args.gpu)
+            if gpu_id >= 0:
+                filter_bias = cuda(filter_bias, gpu_id)
             neg_scores += filter_bias
         # To compute the rank of a positive edge among all negative edges,
         # we need to know how many negative edges have higher scores than
@@ -244,7 +247,8 @@ class KEModel(object):
         return loss, log

-    def update(self):
-        self.entity_emb.update()
-        self.relation_emb.update()
-        self.score_func.update()
+    def update(self, gpu_id=-1):
+        self.entity_emb.update(gpu_id)
+        self.relation_emb.update(gpu_id)
+        self.score_func.update(gpu_id)
@@ -46,7 +46,7 @@ class TransEScore(nn.Block):
             return head, tail
         return fn

-    def update(self):
+    def update(self, gpu_id=-1):
        pass

    def reset_parameters(self):
@@ -171,8 +171,8 @@ class TransRScore(nn.Block):
    def reset_parameters(self):
        self.projection_emb.init(1.0)

-    def update(self):
-        self.projection_emb.update()
+    def update(self, gpu_id=-1):
+        self.projection_emb.update(gpu_id)

    def save(self, path, name):
        self.projection_emb.save(path, name+'projection')
@@ -219,7 +219,7 @@ class DistMultScore(nn.Block):
            return head, tail
        return fn

-    def update(self):
+    def update(self, gpu_id=-1):
        pass

    def reset_parameters(self):
@@ -276,7 +276,7 @@ class ComplExScore(nn.Block):
            return head, tail
        return fn

-    def update(self):
+    def update(self, gpu_id=-1):
        pass

    def reset_parameters(self):
@@ -344,7 +344,7 @@ class RESCALScore(nn.Block):
            return head, tail
        return fn

-    def update(self):
+    def update(self, gpu_id=-1):
        pass

    def reset_parameters(self):
@@ -412,7 +412,7 @@ class RotatEScore(nn.Block):
            return head, tail
        return fn

-    def update(self):
+    def update(self, gpu_id=-1):
        pass

    def reset_parameters(self):
@@ -12,7 +12,7 @@ def logsigmoid(val):
     z = nd.exp(-max_elem) + nd.exp(-val - max_elem)
     return -(max_elem + nd.log(z))

-get_device = lambda args : mx.gpu(args.gpu) if args.gpu >= 0 else mx.cpu()
+get_device = lambda args : mx.gpu(args.gpu[0]) if args.gpu[0] >= 0 else mx.cpu()
 norm = lambda x, p: nd.sum(nd.abs(x) ** p)

 get_scalar = lambda x: x.detach().asscalar()
@@ -44,14 +44,14 @@ class ExternalEmbedding:
         if self.emb.context != idx.context:
             idx = idx.as_in_context(self.emb.context)
         data = nd.take(self.emb, idx)
-        if self.gpu >= 0:
-            data = data.as_in_context(mx.gpu(self.gpu))
+        if gpu_id >= 0:
+            data = data.as_in_context(mx.gpu(gpu_id))
         data.attach_grad()
         if trace:
             self.trace.append((idx, data))
         return data

-    def update(self):
+    def update(self, gpu_id=-1):
         self.state_step += 1
         for idx, data in self.trace:
             grad = data.grad
@@ -71,9 +71,9 @@ class ExternalEmbedding:
                 grad_sum = grad_sum.as_in_context(ctx)
             self.state_sum[grad_indices] += grad_sum
             std = self.state_sum[grad_indices]  # _sparse_mask
+            if gpu_id >= 0:
+                std = std.as_in_context(mx.gpu(gpu_id))
             std_values = nd.expand_dims(nd.sqrt(std) + 1e-10, 1)
-            if self.gpu >= 0:
-                std_values = std_values.as_in_context(mx.gpu(self.args.gpu))
             tmp = (-clr * grad_values / std_values)
             if tmp.context != ctx:
                 tmp = tmp.as_in_context(ctx)
@@ -47,7 +47,7 @@ class TransEScore(nn.Module):
    def forward(self, g):
        g.apply_edges(lambda edges: self.edge_func(edges))

-    def update(self):
+    def update(self, gpu_id=-1):
        pass

    def reset_parameters(self):
@@ -138,8 +138,8 @@ class TransRScore(nn.Module):
    def reset_parameters(self):
        self.projection_emb.init(1.0)

-    def update(self):
-        self.projection_emb.update()
+    def update(self, gpu_id=-1):
+        self.projection_emb.update(gpu_id)

    def save(self, path, name):
        self.projection_emb.save(path, name+'projection')
@@ -186,7 +186,7 @@ class DistMultScore(nn.Module):
            return head, tail
        return fn

-    def update(self):
+    def update(self, gpu_id=-1):
        pass

    def reset_parameters(self):
@@ -243,7 +243,7 @@ class ComplExScore(nn.Module):
            return head, tail
        return fn

-    def update(self):
+    def update(self, gpu_id=-1):
        pass

    def reset_parameters(self):
@@ -314,7 +314,7 @@ class RESCALScore(nn.Module):
            return head, tail
        return fn

-    def update(self):
+    def update(self, gpu_id=-1):
        pass

    def reset_parameters(self):
@@ -373,7 +373,7 @@ class RotatEScore(nn.Module):
        score = score.norm(dim=0)
        return {'score': self.gamma - score.sum(-1)}

-    def update(self):
+    def update(self, gpu_id=-1):
        pass

    def reset_parameters(self):
@@ -23,7 +23,7 @@ from .. import *
 logsigmoid = functional.logsigmoid

 def get_device(args):
-    return th.device('cpu') if args.gpu < 0 else th.device('cuda:' + str(args.gpu))
+    return th.device('cpu') if args.gpu[0] < 0 else th.device('cuda:' + str(args.gpu[0]))

 norm = lambda x, p: x.norm(p=p)**p
@@ -53,8 +53,8 @@ class ExternalEmbedding:
     def __call__(self, idx, gpu_id=-1, trace=True):
         s = self.emb[idx]
-        if self.gpu >= 0:
-            s = s.cuda(self.gpu)
+        if gpu_id >= 0:
+            s = s.cuda(gpu_id)
         # During the training, we need to trace the computation.
         # In this case, we need to record the computation path and compute the gradients.
         if trace:
@@ -64,7 +64,7 @@ class ExternalEmbedding:
             data = s
         return data

-    def update(self):
+    def update(self, gpu_id=-1):
         self.state_step += 1
         with th.no_grad():
             for idx, data in self.trace:
@@ -85,9 +85,9 @@ class ExternalEmbedding:
                     grad_sum = grad_sum.to(device)
                 self.state_sum.index_add_(0, grad_indices, grad_sum)
                 std = self.state_sum[grad_indices]  # _sparse_mask
+                if gpu_id >= 0:
+                    std = std.cuda(gpu_id)
                 std_values = std.sqrt_().add_(1e-10).unsqueeze(1)
-                if self.gpu >= 0:
-                    std_values = std_values.cuda(self.args.gpu)
                 tmp = (-clr * grad_values / std_values)
                 if tmp.device != device:
                     tmp = tmp.to(device)
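Both the MXNet and PyTorch hunks above implement the same row-sparse Adagrad step: the squared-gradient accumulator `state_sum` stays on the embedding's (possibly CPU) device, and only the rows touched by the batch are moved to `gpu_id` for the sqrt-and-divide before the delta is written back. A standalone sketch of the arithmetic, CPU-only for simplicity (all names here are illustrative):

import torch as th

emb = th.randn(10, 4)        # embedding rows (may live in shared CPU memory)
state_sum = th.zeros(10)     # Adagrad accumulator, one scalar per row
clr = 0.1                    # current learning rate

grad_indices = th.tensor([2, 7])   # rows touched by this mini-batch
grad_values = th.randn(2, 4)       # their gradients

# Accumulate mean squared gradient per row, then scale the update per row.
state_sum.index_add_(0, grad_indices, (grad_values * grad_values).mean(1))
std_values = state_sum[grad_indices].sqrt().add_(1e-10).unsqueeze(1)
emb.index_add_(0, grad_indices, -clr * grad_values / std_values)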
@@ -72,8 +72,8 @@ class ArgParser(argparse.ArgumentParser):
         self.add_argument('--no_eval_filter', action='store_true',
                           help='do not filter positive edges among negative edges for evaluation')
-        self.add_argument('--gpu', type=int, default=-1,
-                          help='use GPU')
+        self.add_argument('--gpu', type=int, default=[-1], nargs='+',
+                          help='a list of active gpu ids, e.g. 0 1 2 4')
         self.add_argument('--mix_cpu_gpu', action='store_true',
                           help='mix CPU and GPU training')
         self.add_argument('-de', '--double_ent', action='store_true',
@@ -298,8 +298,9 @@ def run(args, logger):
     if args.num_proc > 1:
         procs = []
         for i in range(args.num_proc):
+            rel_parts = train_data.rel_parts if args.rel_part else None
             valid_samplers = [valid_sampler_heads[i], valid_sampler_tails[i]] if args.valid else None
-            proc = mp.Process(target=train, args=(args, model, train_samplers[i], valid_samplers))
+            proc = mp.Process(target=train, args=(args, model, train_samplers[i], i, rel_parts, valid_samplers))
             procs.append(proc)
             proc.start()
         for proc in procs:
@@ -322,7 +323,7 @@ def run(args, logger):
         procs = []
         for i in range(args.num_proc):
             proc = mp.Process(target=test, args=(args, model, [test_sampler_heads[i], test_sampler_tails[i]],
-                                                 'Test', queue))
+                                                 i, 'Test', queue))
             procs.append(proc)
             proc.start()
@@ -14,11 +14,7 @@ def load_model(logger, args, n_entities, n_relations, ckpt=None):
                     args.hidden_dim, args.gamma,
                     double_entity_emb=args.double_ent, double_relation_emb=args.double_rel)
     if ckpt is not None:
-        # TODO: loading model emb only works for general Embedding, not for ExternalEmbedding
-        if args.gpu >= 0:
-            model.load_parameters(ckpt, ctx=mx.gpu(args.gpu))
-        else:
-            model.load_parameters(ckpt, ctx=mx.cpu())
+        assert False, "We do not support loading model emb for general Embedding"
     logger.info('Load model {}'.format(args.model_name))
     return model
@@ -28,23 +24,28 @@ def load_model_from_checkpoint(logger, args, n_entities, n_relations, ckpt_path)
     model.load_emb(ckpt_path, args.dataset)
     return model

-def train(args, model, train_sampler, valid_samplers=None):
-    if args.num_proc > 1:
-        os.environ['OMP_NUM_THREADS'] = '1'
+def train(args, model, train_sampler, rank=0, rel_parts=None, valid_samplers=None):
+    assert args.num_proc == 1, "MXNet KGE does not support multi-process now"
+    assert args.rel_part == False, "No need for relation partition in single process for MXNet KGE"
     logs = []

     for arg in vars(args):
         logging.info('{:20}:{}'.format(arg, getattr(args, arg)))

+    if len(args.gpu) > 0:
+        gpu_id = args.gpu[rank % len(args.gpu)] if args.mix_cpu_gpu and args.num_proc > 1 else args.gpu[0]
+    else:
+        gpu_id = -1
+
     start = time.time()
     for step in range(args.init_step, args.max_step):
         pos_g, neg_g = next(train_sampler)
         args.step = step

         with mx.autograd.record():
-            loss, log = model.forward(pos_g, neg_g, args.gpu)
+            loss, log = model.forward(pos_g, neg_g, gpu_id)
         loss.backward()
         logs.append(log)
-        model.update()
+        model.update(gpu_id)

         if step % args.log_interval == 0:
             for k in logs[0].keys():
@@ -61,14 +62,20 @@ def train(args, model, train_sampler, valid_samplers=None):
             # clear cache
             logs = []

-def test(args, model, test_samplers, mode='Test', queue=None):
+def test(args, model, test_samplers, rank=0, mode='Test', queue=None):
+    assert args.num_proc == 1, "MXNet KGE does not support multi-process now"
     logs = []

+    if len(args.gpu) > 0:
+        gpu_id = args.gpu[rank % len(args.gpu)] if args.mix_cpu_gpu and args.num_proc > 1 else args.gpu[0]
+    else:
+        gpu_id = -1
+
     for sampler in test_samplers:
         #print('Number of tests: ' + len(sampler))
         count = 0
         for pos_g, neg_g in sampler:
-            model.forward_test(pos_g, neg_g, logs, args.gpu)
+            model.forward_test(pos_g, neg_g, logs, gpu_id)

     metrics = {}
     if len(logs) > 0:
@@ -4,6 +4,8 @@ from torch.utils.data import DataLoader
 import torch.optim as optim
 import torch as th
 import torch.multiprocessing as mp
+from torch.multiprocessing import Queue
+from _thread import start_new_thread

 from distutils.version import LooseVersion
 TH_VERSION = LooseVersion(th.__version__)
@@ -13,14 +15,36 @@ if TH_VERSION.version[0] == 1 and TH_VERSION.version[1] < 2:
 import os
 import logging
 import time
+import traceback  # used below to capture child-thread tracebacks
+from functools import wraps
+
+def thread_wrapped_func(func):
+    @wraps(func)
+    def decorated_function(*args, **kwargs):
+        queue = Queue()
+        def _queue_result():
+            exception, trace, res = None, None, None
+            try:
+                res = func(*args, **kwargs)
+            except Exception as e:
+                exception = e
+                trace = traceback.format_exc()
+            queue.put((res, exception, trace))
+        start_new_thread(_queue_result, ())
+        result, exception, trace = queue.get()
+        if exception is None:
+            return result
+        else:
+            assert isinstance(exception, Exception)
+            raise exception.__class__(trace)
+    return decorated_function
 def load_model(logger, args, n_entities, n_relations, ckpt=None):
     model = KEModel(args, args.model_name, n_entities, n_relations,
                     args.hidden_dim, args.gamma,
                     double_entity_emb=args.double_ent, double_relation_emb=args.double_rel)
     if ckpt is not None:
-        # TODO: loading model emb only works for general Embedding, not for ExternalEmbedding
-        model.load_state_dict(ckpt['model_state_dict'])
+        assert False, "We do not support loading model emb for general Embedding"
     return model
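The `thread_wrapped_func` decorator added above runs the wrapped body in a fresh native thread and relays the return value, or the child's exception plus traceback, back through a queue. This is a known workaround for forked PyTorch worker processes deadlocking on inherited OpenMP/threading state: the fresh thread starts with a clean runtime. A hypothetical usage sketch (the decorated `work` function is made up):

@thread_wrapped_func
def work(x):
    if x < 0:
        raise ValueError('negative input')
    return x * 2

print(work(21))   # 42, computed on a helper thread
try:
    work(-1)      # re-raised with the helper thread's traceback as message
except ValueError as e:
    print('caught:', e)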
@@ -29,13 +53,19 @@ def load_model_from_checkpoint(logger, args, n_entities, n_relations, ckpt_path)
     model.load_emb(ckpt_path, args.dataset)
     return model
-def train(args, model, train_sampler, valid_samplers=None):
+@thread_wrapped_func
+def train(args, model, train_sampler, rank=0, rel_parts=None, valid_samplers=None):
     if args.num_proc > 1:
-        th.set_num_threads(1)
+        th.set_num_threads(4)
     logs = []
     for arg in vars(args):
         logging.info('{:20}:{}'.format(arg, getattr(args, arg)))
+
+    if len(args.gpu) > 0:
+        gpu_id = args.gpu[rank % len(args.gpu)] if args.mix_cpu_gpu and args.num_proc > 1 else args.gpu[0]
+    else:
+        gpu_id = -1
+
     start = time.time()
     sample_time = 0
     update_time = 0
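Each worker derives its device from its rank: with `--mix_cpu_gpu` and several processes, ranks are spread round-robin over the listed GPUs; otherwise everyone uses `args.gpu[0]`. A toy illustration of the mapping (values are made up):

gpus = [0, 1, 2, 3]     # --gpu 0 1 2 3
num_proc = 8            # e.g. two training processes per device
for rank in range(num_proc):
    gpu_id = gpus[rank % len(gpus)]
    print('rank {} -> cuda:{}'.format(rank, gpu_id))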
@@ -48,7 +78,7 @@ def train(args, model, train_sampler, valid_samplers=None):
         args.step = step
         start1 = time.time()
-        loss, log = model.forward(pos_g, neg_g)
+        loss, log = model.forward(pos_g, neg_g, gpu_id)
         forward_time += time.time() - start1

         start1 = time.time()
@@ -56,7 +86,7 @@ def train(args, model, train_sampler, valid_samplers=None):
         backward_time += time.time() - start1

         start1 = time.time()
-        model.update()
+        model.update(gpu_id)
         update_time += time.time() - start1
         logs.append(log)
@@ -80,16 +110,23 @@ def train(args, model, train_sampler, valid_samplers=None):
             test(args, model, valid_samplers, mode='Valid')
             print('test:', time.time() - start)

-def test(args, model, test_samplers, mode='Test', queue=None):
+@thread_wrapped_func
+def test(args, model, test_samplers, rank=0, mode='Test', queue=None):
     if args.num_proc > 1:
-        th.set_num_threads(1)
+        th.set_num_threads(4)
+
+    if len(args.gpu) > 0:
+        gpu_id = args.gpu[rank % len(args.gpu)] if args.mix_cpu_gpu and args.num_proc > 1 else args.gpu[0]
+    else:
+        gpu_id = -1
+
     with th.no_grad():
         logs = []
         for sampler in test_samplers:
             count = 0
             for pos_g, neg_g in sampler:
                 with th.no_grad():
-                    model.forward_test(pos_g, neg_g, logs, args.gpu)
+                    model.forward_test(pos_g, neg_g, logs, gpu_id)

         metrics = {}
         if len(logs) > 0: