Unverified commit 00ba4094 authored by Chao Ma, committed by GitHub

[DGL-KE] Distributed training of DGL-KE (#1290)

* update

* change name

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* change worker number

* update

* update

* update

* update

* update

* update

* test

* update

* update

* update

* remove barrier

* max_step

* update

* add complex

* update

* chmod +x

* update

* update

* random partition

* random partition

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* change num_test_proc

* update num_thread

* update
parent c3a33407
@@ -37,7 +37,7 @@ class KGDataset1:
The triples are stored as 'head_name\trelation_name\ttail_name'.
'''
def __init__(self, path, name):
def __init__(self, path, name, read_triple=True, only_train=False):
url = 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/{}.zip'.format(name)
if not os.path.exists(os.path.join(path, name)):
@@ -66,9 +66,11 @@ class KGDataset1:
self.n_entities = len(self.entity2id)
self.n_relations = len(self.relation2id)
self.train = self.read_triple(path, 'train')
self.valid = self.read_triple(path, 'valid')
self.test = self.read_triple(path, 'test')
if read_triple == True:
self.train = self.read_triple(path, 'train')
if only_train == False:
self.valid = self.read_triple(path, 'valid')
self.test = self.read_triple(path, 'test')
def read_triple(self, path, mode):
# mode: train/valid/test
@@ -102,7 +104,7 @@ class KGDataset2:
The triples are stored as 'head_nid\trelation_id\ttail_nid'.
'''
def __init__(self, path, name):
def __init__(self, path, name, read_triple=True, only_train=False):
url = 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/{}.zip'.format(name)
if not os.path.exists(os.path.join(path, name)):
@@ -110,17 +112,24 @@ class KGDataset2:
_download_and_extract(url, path, '{}.zip'.format(name))
self.path = os.path.join(path, name)
f_ent2id = os.path.join(self.path, 'entity2id.txt')
f_rel2id = os.path.join(self.path, 'relation2id.txt')
with open(f_ent2id) as f_ent:
self.n_entities = int(f_ent.readline()[:-1])
with open(f_rel2id) as f_rel:
self.n_relations = int(f_rel.readline()[:-1])
self.train = self.read_triple(self.path, 'train')
self.valid = self.read_triple(self.path, 'valid')
self.test = self.read_triple(self.path, 'test')
if only_train == True:
f_ent2id = os.path.join(self.path, 'local_to_global.txt')
with open(f_ent2id) as f_ent:
self.n_entities = len(f_ent.readlines())
else:
f_ent2id = os.path.join(self.path, 'entity2id.txt')
with open(f_ent2id) as f_ent:
self.n_entities = int(f_ent.readline()[:-1])
if read_triple == True:
self.train = self.read_triple(self.path, 'train')
if only_train == False:
self.valid = self.read_triple(self.path, 'valid')
self.test = self.read_triple(self.path, 'test')
def read_triple(self, path, mode, skip_first_line=False):
heads = []
@@ -151,3 +160,57 @@ def get_dataset(data_path, data_name, format_str):
dataset = KGDataset2(data_path, data_name)
return dataset
def get_partition_dataset(data_path, data_name, format_str, part_id):
part_name = os.path.join(data_name, 'part_'+str(part_id))
if data_name == 'Freebase':
dataset = KGDataset2(data_path, part_name, read_triple=True, only_train=True)
elif format_str == '1':
dataset = KGDataset1(data_path, part_name, read_triple=True, only_train=True)
else:
dataset = KGDataset2(data_path, part_name, read_triple=True, only_train=True)
path = os.path.join(data_path, part_name)
partition_book = []
with open(os.path.join(path, 'partition_book.txt')) as f:
for line in f:
partition_book.append(int(line))
local_to_global = []
with open(os.path.join(path, 'local_to_global.txt')) as f:
for line in f:
local_to_global.append(int(line))
return dataset, partition_book, local_to_global
def get_server_partition_dataset(data_path, data_name, format_str, part_id):
part_name = os.path.join(data_name, 'part_'+str(part_id))
if data_name == 'Freebase':
dataset = KGDataset2(data_path, part_name, read_triple=False, only_train=True)
elif format_str == '1':
dataset = KGDataset1(data_path, part_name, read_triple=False, only_train=True)
else:
dataset = KGDataset2(data_path, part_name, read_triple=False, only_train=True)
path = os.path.join(data_path, part_name)
n_entities = len(open(os.path.join(path, 'partition_book.txt')).readlines())
local_to_global = []
with open(os.path.join(path, 'local_to_global.txt')) as f:
for line in f:
local_to_global.append(int(line))
global_to_local = [0] * n_entities
for i in range(len(local_to_global)):
global_id = local_to_global[i]
global_to_local[global_id] = i
local_to_global = None
return global_to_local, dataset
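Each partition directory written by the partition step (data_path/<dataset>/part_<id>) is expected to carry its training triples plus two mapping files: partition_book.txt, with one line per global entity ID giving the machine that owns it, and local_to_global.txt, with one line per local entity ID giving its global ID. A rough sketch of how a worker can translate IDs with these files (paths and IDs below are hypothetical):

    import os

    part_dir = os.path.join('../data', 'Freebase', 'part_0')   # hypothetical partition path

    # partition_book.txt: line i -> machine that owns global entity i
    with open(os.path.join(part_dir, 'partition_book.txt')) as f:
        partition_book = [int(line) for line in f]

    # local_to_global.txt: line j -> global ID of the j-th local entity
    with open(os.path.join(part_dir, 'local_to_global.txt')) as f:
        local_to_global = [int(line) for line in f]

    local_id = 42                            # hypothetical local entity ID
    global_id = local_to_global[local_id]    # translate before pulling/pushing embeddings
    owner = partition_book[global_id]        # machine that stores this entity's embedding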
##################################################################################
# This script runs the ComplEx model on the Freebase dataset in a distributed setting.
# You can change the hyper-parameters in this file, but DO NOT run this script manually.
##################################################################################
machine_id=$1
server_count=$2
# Delete the temp file
rm *-shape
##################################################################################
# Start kvserver
##################################################################################
SERVER_ID_LOW=$((machine_id*server_count))
SERVER_ID_HIGH=$(((machine_id+1)*server_count))
while [ $SERVER_ID_LOW -lt $SERVER_ID_HIGH ]
do
MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 DGLBACKEND=pytorch python3 ../kvserver.py --model ComplEx --dataset Freebase \
--hidden_dim 400 --gamma 143.0 --lr 0.1 --total_client 160 --server_id $SERVER_ID_LOW &
let SERVER_ID_LOW+=1
done
##################################################################################
# Start kvclient
##################################################################################
MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 DGLBACKEND=pytorch python3 ../kvclient.py --model ComplEx --dataset Freebase \
--batch_size 1024 --neg_sample_size 256 --hidden_dim 400 --gamma 143.0 --lr 0.1 --max_step 12500 --log_interval 100 \
--batch_size_eval 1000 --neg_sample_size_test 1000 --test -adv --total_machine 4 --num_client 40
\ No newline at end of file
##################################################################################
# This script runs the DistMult model on the Freebase dataset in a distributed setting.
# You can change the hyper-parameters in this file, but DO NOT run this script manually.
##################################################################################
machine_id=$1
server_count=$2
# Delete the temp file
rm *-shape
##################################################################################
# Start kvserver
##################################################################################
SERVER_ID_LOW=$((machine_id*server_count))
SERVER_ID_HIGH=$(((machine_id+1)*server_count))
while [ $SERVER_ID_LOW -lt $SERVER_ID_HIGH ]
do
MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 DGLBACKEND=pytorch python3 ../kvserver.py --model DistMult --dataset Freebase \
--hidden_dim 400 --gamma 143.0 --lr 0.08 --total_client 160 --server_id $SERVER_ID_LOW &
let SERVER_ID_LOW+=1
done
##################################################################################
# Start kvclient
##################################################################################
MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 DGLBACKEND=pytorch python3 ../kvclient.py --model DistMult --dataset Freebase \
--batch_size 1024 --neg_sample_size 256 --hidden_dim 400 --gamma 143.0 --lr 0.08 --max_step 12500 --log_interval 100 \
--batch_size_eval 1000 --neg_sample_size_test 1000 --test -adv --total_machine 4 --num_client 40
\ No newline at end of file
##################################################################################
# This script runs the TransE_l2 model on the Freebase dataset in a distributed setting.
# You can change the hyper-parameters in this file, but DO NOT run this script manually.
##################################################################################
machine_id=$1
server_count=$2
# Delete the temp file
rm *-shape
##################################################################################
# Start kvserver
##################################################################################
SERVER_ID_LOW=$((machine_id*server_count))
SERVER_ID_HIGH=$(((machine_id+1)*server_count))
while [ $SERVER_ID_LOW -lt $SERVER_ID_HIGH ]
do
MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 DGLBACKEND=pytorch python3 ../kvserver.py --model TransE_l2 --dataset Freebase \
--hidden_dim 400 --gamma 10 --lr 0.1 --total_client 160 --server_id $SERVER_ID_LOW &
let SERVER_ID_LOW+=1
done
##################################################################################
# Start kvclient
##################################################################################
MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 DGLBACKEND=pytorch python3 ../kvclient.py --model TransE_l2 --dataset Freebase \
--batch_size 1000 --neg_sample_size 200 --hidden_dim 400 --gamma 10 --lr 0.1 --max_step 12500 --log_interval 100 \
--batch_size_eval 1000 --neg_sample_size_test 1000 --test -adv --regularization_coef 1e-9 --total_machine 4 --num_client 40
\ No newline at end of file
127.0.0.1 30050 8
127.0.0.1 30050 8
127.0.0.1 30050 8
127.0.0.1 30050 8
\ No newline at end of file
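Each line of ip_config.txt is '[ip] [base_port] [server_count]' (see the read_ip_config error-message fix at the bottom of this commit). The placeholder file above therefore describes a 4-machine cluster with 8 kvserver processes per machine (32 in total); with --num_client 40 clients per machine that gives 4 x 40 = 160 clients, matching --total_client 160 in the scripts above. A hypothetical config with real private addresses might look like:

    172.31.0.1 30050 8
    172.31.0.2 30050 8
    172.31.0.3 30050 8
    172.31.0.4 30050 8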
##################################################################################
# The user runs this script to launch distributed jobs on the cluster
##################################################################################
script_path=~/dgl/apps/kg/distributed
script_file=./freebase_transe_l2.sh
user_name=ubuntu
ssh_key=~/mctt.pem
server_count=$(awk 'NR==1 {print $3}' ip_config.txt)
# run command on remote machine
LINE_LOW=2
LINE_HIGH=$(awk 'END{print NR}' ip_config.txt)
let LINE_HIGH+=1
s_id=0
while [ $LINE_LOW -lt $LINE_HIGH ]
do
ip=$(awk 'NR=='$LINE_LOW' {print $1}' ip_config.txt)
let LINE_LOW+=1
let s_id+=1
ssh -i $ssh_key $user_name@$ip 'cd '$script_path'; '$script_file' '$s_id' '$server_count' ' &
done
# run command on local machine
$script_file 0 $server_count
\ No newline at end of file
@@ -66,6 +66,8 @@ class ArgParser(argparse.ArgumentParser):
help='number of workers used for loading data')
self.add_argument('--num_proc', type=int, default=1,
help='number of process used')
self.add_argument('--num_thread', type=int, default=1,
help='number of threads used')
def parse_args(self):
args = super().parse_args()
import os
import argparse
import time
import logging
import socket
if os.name != 'nt':
import fcntl
import struct
import torch.multiprocessing as mp
from train_pytorch import load_model, dist_train_test
from train import get_logger
from dataloader import TrainDataset, NewBidirectionalOneShotIterator
from dataloader import get_dataset, get_partition_dataset
import dgl
import dgl.backend as F
NUM_THREAD = 1 # Fix the number of threads to 1 on kvclient
NUM_WORKER = 1 # Fix the number of workers for the sampler to 1
class ArgParser(argparse.ArgumentParser):
def __init__(self):
super(ArgParser, self).__init__()
self.add_argument('--model_name', default='TransE',
choices=['TransE', 'TransE_l1', 'TransE_l2', 'TransR',
'RESCAL', 'DistMult', 'ComplEx', 'RotatE'],
help='model to use')
self.add_argument('--data_path', type=str, default='../data',
help='root path of all dataset')
self.add_argument('--dataset', type=str, default='FB15k',
help='dataset name, under data_path')
self.add_argument('--format', type=str, default='1',
help='the format of the dataset.')
self.add_argument('--save_path', type=str, default='ckpts',
help='place to save models and logs')
self.add_argument('--save_emb', type=str, default=None,
help='save the embeddings in the specific location.')
self.add_argument('--max_step', type=int, default=80000,
help='number of training steps')
self.add_argument('--warm_up_step', type=int, default=None,
help='for learning rate decay')
self.add_argument('--batch_size', type=int, default=1024,
help='batch size')
self.add_argument('--batch_size_eval', type=int, default=8,
help='batch size used for eval and test')
self.add_argument('--neg_sample_size', type=int, default=128,
help='negative sampling size')
self.add_argument('--neg_chunk_size', type=int, default=-1,
help='chunk size of the negative edges.')
self.add_argument('--neg_deg_sample', action='store_true',
help='negative sample proportional to vertex degree in the training')
self.add_argument('--neg_deg_sample_eval', action='store_true',
help='negative sampling proportional to vertex degree in the evaluation')
self.add_argument('--neg_sample_size_valid', type=int, default=1000,
help='negative sampling size for validation')
self.add_argument('--neg_chunk_size_valid', type=int, default=-1,
help='chunk size of the negative edges.')
self.add_argument('--neg_sample_size_test', type=int, default=-1,
help='negative sampling size for testing')
self.add_argument('--neg_chunk_size_test', type=int, default=-1,
help='chunk size of the negative edges.')
self.add_argument('--hidden_dim', type=int, default=256,
help='hidden dim used by relation and entity')
self.add_argument('--lr', type=float, default=0.0001,
help='learning rate')
self.add_argument('-g', '--gamma', type=float, default=12.0,
help='margin value')
self.add_argument('--eval_percent', type=float, default=1,
help='sample some percentage for evaluation.')
self.add_argument('--no_eval_filter', action='store_true',
help='do not filter positive edges among negative edges for evaluation')
self.add_argument('--gpu', type=int, default=[-1], nargs='+',
help='a list of active gpu ids, e.g. 0 1 2 4')
self.add_argument('--mix_cpu_gpu', action='store_true',
help='mix CPU and GPU training')
self.add_argument('-de', '--double_ent', action='store_true',
help='double entity dim for complex number')
self.add_argument('-dr', '--double_rel', action='store_true',
help='double relation dim for complex number')
self.add_argument('--seed', type=int, default=0,
help='set random seed for reproducibility')
self.add_argument('-log', '--log_interval', type=int, default=1000,
help='print the training log every x steps')
self.add_argument('--eval_interval', type=int, default=10000,
help='do evaluation after every x steps')
self.add_argument('-adv', '--neg_adversarial_sampling', action='store_true',
help='if use negative adversarial sampling')
self.add_argument('-a', '--adversarial_temperature', default=1.0, type=float)
self.add_argument('--valid', action='store_true',
help='evaluate the model on the validation set')
self.add_argument('--test', action='store_true',
help='evaluate the model on the test set')
self.add_argument('-rc', '--regularization_coef', type=float, default=0.000002,
help='set value > 0.0 if regularization is used')
self.add_argument('-rn', '--regularization_norm', type=int, default=3,
help='norm used in regularization')
self.add_argument('--num_worker', type=int, default=32,
help='number of workers used for loading data')
self.add_argument('--non_uni_weight', action='store_true',
help='use non-uniform weight when computing loss')
self.add_argument('--init_step', type=int, default=0,
help='DONT SET MANUALLY, used for resume')
self.add_argument('--step', type=int, default=0,
help='DONT SET MANUALLY, track current step')
self.add_argument('--pickle_graph', action='store_true',
help='pickle built graph, building a huge graph is slow.')
self.add_argument('--num_proc', type=int, default=1,
help='number of processes used')
self.add_argument('--rel_part', action='store_true',
help='enable relation partitioning')
self.add_argument('--soft_rel_part', action='store_true',
help='enable soft relation partition')
self.add_argument('--nomp_thread_per_process', type=int, default=-1,
help='num of omp threads used per process in multi-process training')
self.add_argument('--async_update', action='store_true',
help='allow async_update on node embedding')
self.add_argument('--force_sync_interval', type=int, default=-1,
help='We force a synchronization between processes every x steps')
self.add_argument('--strict_rel_part', action='store_true',
help='Strict relation partition')
self.add_argument('--machine_id', type=int, default=0,
help='Unique ID of current machine.')
self.add_argument('--total_machine', type=int, default=1,
help='Total number of machines.')
self.add_argument('--ip_config', type=str, default='ip_config.txt',
help='IP configuration file of kvstore')
self.add_argument('--num_client', type=int, default=1,
help='Number of clients on each machine.')
def get_long_tail_partition(n_relations, n_machine):
"""Relation types has a long tail distribution for many dataset.
So we need to average shuffle the data before we partition it.
"""
assert n_relations > 0, 'n_relations must be a positive number.'
assert n_machine > 0, 'n_machine must be a positive number.'
partition_book = [0] * n_relations
part_id = 0
for i in range(n_relations):
partition_book[i] = part_id
part_id += 1
if part_id == n_machine:
part_id = 0
return partition_book
def local_ip4_addr_list():
"""Return a set of IPv4 address
"""
nic = set()
for ix in socket.if_nameindex():
name = ix[1]
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
ip = socket.inet_ntoa(fcntl.ioctl(
s.fileno(),
0x8915, # SIOCGIFADDR
struct.pack('256s', name[:15].encode("UTF-8")))[20:24])
nic.add(ip)
return nic
def get_local_machine_id(server_namebook):
"""Get machine ID via server_namebook
"""
assert len(server_namebook) > 0, 'server_namebook cannot be empty.'
res = 0
for ID, data in server_namebook.items():
machine_id = data[0]
ip = data[1]
if ip in local_ip4_addr_list():
res = machine_id
break
return res
def start_worker(args, logger):
"""Start kvclient for training
"""
train_time_start = time.time()
server_namebook = dgl.contrib.read_ip_config(filename=args.ip_config)
args.machine_id = get_local_machine_id(server_namebook)
dataset, entity_partition_book, local2global = get_partition_dataset(
args.data_path,
args.dataset,
args.format,
args.machine_id)
n_entities = dataset.n_entities
n_relations = dataset.n_relations
print('Partition %d n_entities: %d' % (args.machine_id, n_entities))
print("Partition %d n_relations: %d" % (args.machine_id, n_relations))
entity_partition_book = F.tensor(entity_partition_book)
relation_partition_book = get_long_tail_partition(dataset.n_relations, args.total_machine)
relation_partition_book = F.tensor(relation_partition_book)
local2global = F.tensor(local2global)
relation_partition_book.share_memory_()
entity_partition_book.share_memory_()
local2global.share_memory_()
model = load_model(logger, args, n_entities, n_relations)
model.share_memory()
# When we generate a batch of negative edges from a set of positive edges,
# we first divide the positive edges into chunks and corrupt the edges in a chunk
# together. By default, the chunk size is equal to the negative sample size.
# Usually, this works well. But we also allow users to specify the chunk size themselves.
if args.neg_chunk_size < 0:
args.neg_chunk_size = args.neg_sample_size
num_workers = NUM_WORKER
train_data = TrainDataset(dataset, args, ranks=args.num_client)
train_samplers = []
for i in range(args.num_client):
train_sampler_head = train_data.create_sampler(args.batch_size,
args.neg_sample_size,
args.neg_chunk_size,
mode='head',
num_workers=num_workers,
shuffle=True,
exclude_positive=False,
rank=i)
train_sampler_tail = train_data.create_sampler(args.batch_size,
args.neg_sample_size,
args.neg_chunk_size,
mode='tail',
num_workers=num_workers,
shuffle=True,
exclude_positive=False,
rank=i)
train_samplers.append(NewBidirectionalOneShotIterator(train_sampler_head, train_sampler_tail,
args.neg_chunk_size, args.neg_sample_size,
True, n_entities))
dataset = None
print('Total data loading time {:.3f} seconds'.format(time.time() - train_time_start))
rel_parts = train_data.rel_parts if args.strict_rel_part or args.soft_rel_part else None
cross_rels = train_data.cross_rels if args.soft_rel_part else None
args.num_thread = NUM_THREAD
procs = []
barrier = mp.Barrier(args.num_client)
for i in range(args.num_client):
proc = mp.Process(target=dist_train_test, args=(args,
model,
train_samplers[i],
entity_partition_book,
relation_partition_book,
local2global,
i,
rel_parts,
cross_rels,
barrier))
procs.append(proc)
proc.start()
for proc in procs:
proc.join()
if __name__ == '__main__':
args = ArgParser().parse_args()
logger = get_logger(args)
start_worker(args, logger)
\ No newline at end of file
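get_long_tail_partition above spreads relation types over machines round-robin, so the heavy head of the long-tailed relation distribution does not land on a single partition. A tiny self-contained illustration of the assignment it produces (the logic is equivalent to a modulo; sizes are made up):

    # 7 relation types, 3 machines -> round-robin ownership
    n_relations, n_machine = 7, 3
    partition_book = [i % n_machine for i in range(n_relations)]
    print(partition_book)   # [0, 1, 2, 0, 1, 2, 0]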
import os
import argparse
import time
import dgl
from dgl.contrib import KVServer
import torch as th
from train_pytorch import load_model
from dataloader import get_server_partition_dataset
NUM_THREAD = 1 # Fix the number of threads to 1 on kvstore
class KGEServer(KVServer):
"""User-defined kvstore for DGL-KGE
"""
def _push_handler(self, name, ID, data, target):
"""Row-Sparse Adagrad updater
"""
original_name = name[0:-6]
state_sum = target[original_name+'_state-data-']
grad_sum = (data * data).mean(1)
state_sum.index_add_(0, ID, grad_sum)
std = state_sum[ID] # _sparse_mask
std_values = std.sqrt_().add_(1e-10).unsqueeze(1)
tmp = (-self.clr * data / std_values)
target[name].index_add_(0, ID, tmp)
def set_clr(self, learning_rate):
"""Set learning rate for Row-Sparse Adagrad updater
"""
self.clr = learning_rate
# Note: Most of the args are unnecessary for KVStore, will remove them later
class ArgParser(argparse.ArgumentParser):
def __init__(self):
super(ArgParser, self).__init__()
self.add_argument('--model_name', default='TransE',
choices=['TransE', 'TransE_l1', 'TransE_l2', 'TransR',
'RESCAL', 'DistMult', 'ComplEx', 'RotatE'],
help='model to use')
self.add_argument('--data_path', type=str, default='../data',
help='root path of all dataset')
self.add_argument('--dataset', type=str, default='FB15k',
help='dataset name, under data_path')
self.add_argument('--format', type=str, default='1',
help='the format of the dataset.')
self.add_argument('--hidden_dim', type=int, default=256,
help='hidden dim used by relation and entity')
self.add_argument('--lr', type=float, default=0.0001,
help='learning rate')
self.add_argument('-g', '--gamma', type=float, default=12.0,
help='margin value')
self.add_argument('--gpu', type=int, default=[-1], nargs='+',
help='a list of active gpu ids, e.g. 0')
self.add_argument('--mix_cpu_gpu', action='store_true',
help='mix CPU and GPU training')
self.add_argument('-de', '--double_ent', action='store_true',
help='double entity dim for complex number')
self.add_argument('-dr', '--double_rel', action='store_true',
help='double relation dim for complex number')
self.add_argument('--seed', type=int, default=0,
help='set random seed for reproducibility')
self.add_argument('--rel_part', action='store_true',
help='enable relation partitioning')
self.add_argument('--soft_rel_part', action='store_true',
help='enable soft relation partition')
self.add_argument('--nomp_thread_per_process', type=int, default=-1,
help='num of omp threads used per process in multi-process training')
self.add_argument('--async_update', action='store_true',
help='allow async_update on node embedding')
self.add_argument('--strict_rel_part', action='store_true',
help='Strict relation partition')
self.add_argument('--server_id', type=int, default=0,
help='Unique ID of each server')
self.add_argument('--ip_config', type=str, default='ip_config.txt',
help='IP configuration file of kvstore')
self.add_argument('--total_client', type=int, default=1,
help='Total number of client worker nodes')
def get_server_data(args, machine_id):
"""Get data from data_path/dataset/part_machine_id
Return: global2local,
entity_emb,
entity_state,
relation_emb,
relation_emb_state
"""
g2l, dataset = get_server_partition_dataset(
args.data_path,
args.dataset,
args.format,
machine_id)
# Note that the dataset doesn't contain the triples
print('n_entities: ' + str(dataset.n_entities))
print('n_relations: ' + str(dataset.n_relations))
model = load_model(None, args, dataset.n_entities, dataset.n_relations)
return g2l, model.entity_emb.emb, model.entity_emb.state_sum, model.relation_emb.emb, model.relation_emb.state_sum
def start_server(args):
"""Start kvstore service
"""
th.set_num_threads(NUM_THREAD)
server_namebook = dgl.contrib.read_ip_config(filename=args.ip_config)
my_server = KGEServer(server_id=args.server_id,
server_namebook=server_namebook,
num_client=args.total_client)
my_server.set_clr(args.lr)
if my_server.get_id() % my_server.get_group_count() == 0: # master server
g2l, entity_emb, entity_emb_state, relation_emb, relation_emb_state = get_server_data(args, my_server.get_machine_id())
my_server.set_global2local(name='entity_emb', global2local=g2l)
my_server.init_data(name='relation_emb', data_tensor=relation_emb)
my_server.init_data(name='relation_emb_state', data_tensor=relation_emb_state)
my_server.init_data(name='entity_emb', data_tensor=entity_emb)
my_server.init_data(name='entity_emb_state', data_tensor=entity_emb_state)
else: # backup server
my_server.set_global2local(name='entity_emb')
my_server.init_data(name='relation_emb')
my_server.init_data(name='relation_emb_state')
my_server.init_data(name='entity_emb')
my_server.init_data(name='entity_emb_state')
print('KVServer %d listen for requests ...' % my_server.get_id())
my_server.start()
if __name__ == '__main__':
args = ArgParser().parse_args()
start_server(args)
\ No newline at end of file
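KGEServer._push_handler above (and the identical handler in KGEClient later in this commit) implements a row-sparse Adagrad rule: the mean squared gradient of each pushed row is accumulated into a per-row state, and the row is updated by -lr * grad / sqrt(state). A standalone sketch of that rule (shapes and values are illustrative only, not taken from the commit):

    import torch as th

    lr = 0.1
    emb = th.zeros(5, 4)       # embedding table: 5 rows of dimension 4
    state_sum = th.zeros(5)    # per-row Adagrad state

    ID = th.tensor([1, 3])     # rows touched by this push
    grad = th.randn(2, 4)      # gradients pushed for those rows

    state_sum.index_add_(0, ID, (grad * grad).mean(1))    # accumulate mean squared grad
    std = state_sum[ID].sqrt().add_(1e-10).unsqueeze(1)   # per-row scale, shape (2, 1)
    emb.index_add_(0, ID, -lr * grad / std)               # sparse Adagrad step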
@@ -462,3 +462,36 @@ class KEModel(object):
"""Terminate the async update for entity embedding.
"""
self.entity_emb.finish_async_update()
def pull_model(self, client, pos_g, neg_g):
with th.no_grad():
entity_id = F.cat(seq=[pos_g.ndata['id'], neg_g.ndata['id']], dim=0)
relation_id = pos_g.edata['id']
entity_id = F.tensor(np.unique(F.asnumpy(entity_id)))
relation_id = F.tensor(np.unique(F.asnumpy(relation_id)))
l2g = client.get_local2global()
global_entity_id = l2g[entity_id]
entity_data = client.pull(name='entity_emb', id_tensor=global_entity_id)
relation_data = client.pull(name='relation_emb', id_tensor=relation_id)
self.entity_emb.emb[entity_id] = entity_data
self.relation_emb.emb[relation_id] = relation_data
def push_gradient(self, client):
with th.no_grad():
l2g = client.get_local2global()
for entity_id, entity_data in self.entity_emb.trace:
grad = entity_data.grad.data
global_entity_id = l2g[entity_id]
client.push(name='entity_emb', id_tensor=global_entity_id, data_tensor=grad)
for relation_id, relation_data in self.relation_emb.trace:
grad = relation_data.grad.data
client.push(name='relation_emb', id_tensor=relation_id, data_tensor=grad)
self.entity_emb.trace = []
self.relation_emb.trace = []
\ No newline at end of file
@@ -110,6 +110,10 @@ class ArgParser(argparse.ArgumentParser):
help='pickle built graph, building a huge graph is slow.')
self.add_argument('--num_proc', type=int, default=1,
help='number of process used')
self.add_argument('--num_test_proc', type=int, default=1,
help='number of processes used for test')
self.add_argument('--num_thread', type=int, default=1,
help='number of threads used')
self.add_argument('--rel_part', action='store_true',
help='enable relation partitioning')
self.add_argument('--soft_rel_part', action='store_true',
@@ -149,6 +153,7 @@ def get_logger(args):
def run(args, logger):
train_time_start = time.time()
# load dataset and samplers
dataset = get_dataset(args.data_path, args.dataset, args.format)
n_entities = dataset.n_entities
@@ -185,7 +190,7 @@ def run(args, logger):
args.num_thread = 4
else:
# CPU training
args.num_thread = mp.cpu_count() // args.num_proc + 1
args.num_thread = 1
else:
args.num_thread = args.nomp_thread_per_process
@@ -235,7 +240,10 @@ def run(args, logger):
if args.num_proc > 1:
num_workers = 1
if args.valid or args.test:
args.num_test_proc = args.num_proc if args.num_proc < len(args.gpu) else len(args.gpu)
if len(args.gpu) > 1:
args.num_test_proc = args.num_proc if args.num_proc < len(args.gpu) else len(args.gpu)
else:
args.num_test_proc = args.num_proc
eval_dataset = EvalDataset(dataset, args)
if args.valid:
# Here we want to use the regular negative sampler because we need to ensure that
@@ -324,6 +332,8 @@ def run(args, logger):
if args.num_proc > 1 or args.async_update:
model.share_memory()
print('Total data loading time {:.3f} seconds'.format(time.time() - train_time_start))
# train
start = time.time()
rel_parts = train_data.rel_parts if args.strict_rel_part or args.soft_rel_part else None
from models import KEModel
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
import torch.optim as optim
import torch as th
@@ -15,6 +16,66 @@ import logging
import time
from functools import wraps
import dgl
from dgl.contrib import KVClient
import dgl.backend as F
from dataloader import EvalDataset
from dataloader import get_dataset
class KGEClient(KVClient):
"""User-defined kvclient for DGL-KGE
"""
def _push_handler(self, name, ID, data, target):
"""Row-Sparse Adagrad updater
"""
original_name = name[0:-6]
state_sum = target[original_name+'_state-data-']
grad_sum = (data * data).mean(1)
state_sum.index_add_(0, ID, grad_sum)
std = state_sum[ID] # _sparse_mask
std_values = std.sqrt_().add_(1e-10).unsqueeze(1)
tmp = (-self.clr * data / std_values)
target[name].index_add_(0, ID, tmp)
def set_clr(self, learning_rate):
"""Set learning rate
"""
self.clr = learning_rate
def set_local2global(self, l2g):
self._l2g = l2g
def get_local2global(self):
return self._l2g
def connect_to_kvstore(args, entity_pb, relation_pb, l2g):
"""Create kvclient and connect to kvstore service
"""
server_namebook = dgl.contrib.read_ip_config(filename=args.ip_config)
my_client = KGEClient(server_namebook=server_namebook)
my_client.set_clr(args.lr)
my_client.connect()
if my_client.get_id() % args.num_client == 0:
my_client.set_partition_book(name='entity_emb', partition_book=entity_pb)
my_client.set_partition_book(name='relation_emb', partition_book=relation_pb)
else:
my_client.set_partition_book(name='entity_emb')
my_client.set_partition_book(name='relation_emb')
my_client.set_local2global(l2g)
return my_client
def load_model(logger, args, n_entities, n_relations, ckpt=None):
model = KEModel(args, args.model_name, n_entities, n_relations,
args.hidden_dim, args.gamma,
@@ -29,7 +90,7 @@ def load_model_from_checkpoint(logger, args, n_entities, n_relations, ckpt_path)
model.load_emb(ckpt_path, args.dataset)
return model
def train(args, model, train_sampler, valid_samplers=None, rank=0, rel_parts=None, cross_rels=None, barrier=None):
def train(args, model, train_sampler, valid_samplers=None, rank=0, rel_parts=None, cross_rels=None, barrier=None, client=None):
logs = []
for arg in vars(args):
logging.info('{:20}:{}'.format(arg, getattr(args, arg)))
@@ -46,7 +107,7 @@ def train(args, model, train_sampler, valid_samplers=None, rank=0, rel_parts=Non
if args.soft_rel_part:
model.prepare_cross_rels(cross_rels)
start = time.time()
train_start = start = time.time()
sample_time = 0
update_time = 0
forward_time = 0
@@ -57,6 +118,9 @@ def train(args, model, train_sampler, valid_samplers=None, rank=0, rel_parts=Non
sample_time += time.time() - start1
args.step = step
if client is not None:
model.pull_model(client, pos_g, neg_g)
start1 = time.time()
loss, log = model.forward(pos_g, neg_g, gpu_id)
forward_time += time.time() - start1
@@ -66,7 +130,10 @@ def train(args, model, train_sampler, valid_samplers=None, rank=0, rel_parts=Non
backward_time += time.time() - start1
start1 = time.time()
model.update(gpu_id)
if client is not None:
model.push_gradient(client)
else:
model.update(gpu_id)
update_time += time.time() - start1
logs.append(log)
@@ -104,7 +171,7 @@ def train(args, model, train_sampler, valid_samplers=None, rank=0, rel_parts=Non
if barrier is not None:
barrier.wait()
print('train {} takes {:.3f} seconds'.format(rank, time.time() - start))
print('train {} takes {:.3f} seconds'.format(rank, time.time() - train_start))
if args.async_update:
model.finish_async_update()
if args.strict_rel_part or args.soft_rel_part:
@@ -147,3 +214,129 @@ def train_mp(args, model, train_sampler, valid_samplers=None, rank=0, rel_parts=
@thread_wrapped_func
def test_mp(args, model, test_samplers, rank=0, mode='Test', queue=None):
test(args, model, test_samplers, rank, mode, queue)
@thread_wrapped_func
def dist_train_test(args, model, train_sampler, entity_pb, relation_pb, l2g, rank=0, rel_parts=None, cross_rels=None, barrier=None):
if args.num_proc > 1:
th.set_num_threads(args.num_thread)
client = connect_to_kvstore(args, entity_pb, relation_pb, l2g)
client.barrier()
train_time_start = time.time()
train(args, model, train_sampler, None, rank, rel_parts, cross_rels, barrier, client)
client.barrier()
print('Total train time {:.3f} seconds'.format(time.time() - train_time_start))
model = None
if client.get_id() % args.num_client == 0: # pull full model from kvstore
args.num_test_proc = args.num_client
dataset_full = get_dataset(args.data_path, args.dataset, args.format)
print('Full data n_entities: ' + str(dataset_full.n_entities))
print("Full data n_relations: " + str(dataset_full.n_relations))
model_test = load_model(None, args, dataset_full.n_entities, dataset_full.n_relations)
eval_dataset = EvalDataset(dataset_full, args)
if args.test:
model_test.share_memory()
if args.neg_sample_size_test < 0:
args.neg_sample_size_test = dataset_full.n_entities
args.eval_filter = not args.no_eval_filter
if args.neg_deg_sample_eval:
assert not args.eval_filter, "if negative sampling based on degree, we can't filter positive edges."
if args.neg_chunk_size_valid < 0:
args.neg_chunk_size_valid = args.neg_sample_size_valid
if args.neg_chunk_size_test < 0:
args.neg_chunk_size_test = args.neg_sample_size_test
print("Pull relation_emb ...")
relation_id = F.arange(0, model_test.n_relations)
relation_data = client.pull(name='relation_emb', id_tensor=relation_id)
model_test.relation_emb.emb[relation_id] = relation_data
print("Pull entity_emb ... ")
# pull the entity embedding table in ~100 slices to keep each pull message small
start = 0
percent = 0
entity_id = F.arange(0, model_test.n_entities)
count = int(model_test.n_entities / 100)
end = start + count
while True:
print("Pull %d / 100 ..." % percent)
if end >= model_test.n_entities:
end = model_test.n_entities  # clamp so the final slice includes the last entity
tmp_id = entity_id[start:end]
entity_data = client.pull(name='entity_emb', id_tensor=tmp_id)
model_test.entity_emb.emb[tmp_id] = entity_data
if end == model_test.n_entities:
break
start = end
end += count
percent += 1
if args.save_emb is not None:
if not os.path.exists(args.save_emb):
os.mkdir(args.save_emb)
model_test.save_emb(args.save_emb, args.dataset)
if args.test:
args.num_thread = 1
test_sampler_tails = []
test_sampler_heads = []
for i in range(args.num_test_proc):
test_sampler_head = eval_dataset.create_sampler('test', args.batch_size_eval,
args.neg_sample_size_test,
args.neg_chunk_size_test,
args.eval_filter,
mode='chunk-head',
num_workers=args.num_thread,
rank=i, ranks=args.num_test_proc)
test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval,
args.neg_sample_size_test,
args.neg_chunk_size_test,
args.eval_filter,
mode='chunk-tail',
num_workers=args.num_thread,
rank=i, ranks=args.num_test_proc)
test_sampler_heads.append(test_sampler_head)
test_sampler_tails.append(test_sampler_tail)
eval_dataset = None
dataset_full = None
print("Run test, test processes: %d" % args.num_test_proc)
queue = mp.Queue(args.num_test_proc)
procs = []
for i in range(args.num_test_proc):
proc = mp.Process(target=test_mp, args=(args,
model_test,
[test_sampler_heads[i], test_sampler_tails[i]],
i,
'Test',
queue))
procs.append(proc)
proc.start()
total_metrics = {}
metrics = {}
logs = []
for i in range(args.num_test_proc):
log = queue.get()
logs = logs + log
for metric in logs[0].keys():
metrics[metric] = sum([log[metric] for log in logs]) / len(logs)
for k, v in metrics.items():
print('Test average {} at [{}/{}]: {}'.format(k, args.step, args.max_step, v))
for proc in procs:
proc.join()
if client.get_id() == 0:
client.shut_down()
\ No newline at end of file
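The entity table above is pulled from the kvstore in roughly 100 slices so that no single pull message gets too large. A quick self-contained check of that chunking pattern (with a made-up entity count), showing every row is fetched exactly once:

    n_entities = 1003              # made-up size; Freebase is far larger
    count = n_entities // 100      # ~1% of the table per pull
    covered = 0
    start, end = 0, count
    while True:
        if end >= n_entities:
            end = n_entities       # clamp the final slice
        covered += end - start
        if end == n_entities:
            break
        start, end = end, end + count
    assert covered == n_entities   # all entity embeddings are pulled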
@@ -72,7 +72,7 @@ def read_ip_config(filename):
server_id += 1
machine_id += 1
except:
print("Error: data format on each line should be: [machine_id] [ip] [base_port] [server_count]")
print("Error: data format on each line should be: [ip] [base_port] [server_count]")
return server_namebook