"tests/python/pytorch/sparse/test_reduction.py" did not exist on "f62669b05c73395404d0eb281d7657fb0f84790a"
Unverified Commit 00ba4094 authored by Chao Ma, committed by GitHub

[DGL-KE] Distributed training of DGL-KE (#1290)

* update

* change name

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* change worker number

* update

* update

* update

* update

* update

* update

* test

* update

* update

* update

* remove barrier

* max_step

* update

* add complex

* update

* chmod +x

* update

* update

* random partition

* random partition

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* change num_test_proc

* update num_thread

* update
parent c3a33407
@@ -37,7 +37,7 @@ class KGDataset1:
     The triples are stored as 'head_name\trelation_name\ttail_name'.
     '''
-    def __init__(self, path, name):
+    def __init__(self, path, name, read_triple=True, only_train=False):
         url = 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/{}.zip'.format(name)
         if not os.path.exists(os.path.join(path, name)):
@@ -66,9 +66,11 @@ class KGDataset1:
         self.n_entities = len(self.entity2id)
         self.n_relations = len(self.relation2id)
-        self.train = self.read_triple(path, 'train')
-        self.valid = self.read_triple(path, 'valid')
-        self.test = self.read_triple(path, 'test')
+        if read_triple == True:
+            self.train = self.read_triple(path, 'train')
+            if only_train == False:
+                self.valid = self.read_triple(path, 'valid')
+                self.test = self.read_triple(path, 'test')
     def read_triple(self, path, mode):
         # mode: train/valid/test
@@ -102,7 +104,7 @@ class KGDataset2:
     The triples are stored as 'head_nid\trelation_id\ttail_nid'.
     '''
-    def __init__(self, path, name):
+    def __init__(self, path, name, read_triple=True, only_train=False):
         url = 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/{}.zip'.format(name)
         if not os.path.exists(os.path.join(path, name)):
@@ -110,17 +112,24 @@ class KGDataset2:
             _download_and_extract(url, path, '{}.zip'.format(name))
         self.path = os.path.join(path, name)
-        f_ent2id = os.path.join(self.path, 'entity2id.txt')
         f_rel2id = os.path.join(self.path, 'relation2id.txt')
-        with open(f_ent2id) as f_ent:
-            self.n_entities = int(f_ent.readline()[:-1])
         with open(f_rel2id) as f_rel:
             self.n_relations = int(f_rel.readline()[:-1])
-        self.train = self.read_triple(self.path, 'train')
-        self.valid = self.read_triple(self.path, 'valid')
-        self.test = self.read_triple(self.path, 'test')
+        if only_train == True:
+            f_ent2id = os.path.join(self.path, 'local_to_global.txt')
+            with open(f_ent2id) as f_ent:
+                self.n_entities = len(f_ent.readlines())
+        else:
+            f_ent2id = os.path.join(self.path, 'entity2id.txt')
+            with open(f_ent2id) as f_ent:
+                self.n_entities = int(f_ent.readline()[:-1])
+        if read_triple == True:
+            self.train = self.read_triple(self.path, 'train')
+            if only_train == False:
+                self.valid = self.read_triple(self.path, 'valid')
+                self.test = self.read_triple(self.path, 'test')
     def read_triple(self, path, mode, skip_first_line=False):
         heads = []
@@ -151,3 +160,57 @@ def get_dataset(data_path, data_name, format_str):
         dataset = KGDataset2(data_path, data_name)
     return dataset
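# The two helpers below assume the partitioned dataset layout produced offline:
# data_path/<data_name>/part_<part_id>/ holds this machine's training triples plus
#   partition_book.txt  - line i gives the partition (machine) that owns global entity i
#   local_to_global.txt - line i gives the global entity ID of local entity i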
def get_partition_dataset(data_path, data_name, format_str, part_id):
part_name = os.path.join(data_name, 'part_'+str(part_id))
if data_name == 'Freebase':
dataset = KGDataset2(data_path, part_name, read_triple=True, only_train=True)
elif format_str == '1':
dataset = KGDataset1(data_path, part_name, read_triple=True, only_train=True)
else:
dataset = KGDataset2(data_path, part_name, read_triple=True, only_train=True)
path = os.path.join(data_path, part_name)
partition_book = []
with open(os.path.join(path, 'partition_book.txt')) as f:
for line in f:
partition_book.append(int(line))
local_to_global = []
with open(os.path.join(path, 'local_to_global.txt')) as f:
for line in f:
local_to_global.append(int(line))
return dataset, partition_book, local_to_global
def get_server_partition_dataset(data_path, data_name, format_str, part_id):
part_name = os.path.join(data_name, 'part_'+str(part_id))
if data_name == 'Freebase':
dataset = KGDataset2(data_path, part_name, read_triple=False, only_train=True)
elif format_str == '1':
dataset = KGDataset1(data_path, part_name, read_triple=False, only_train=True)
else:
dataset = KGDataset2(data_path, part_name, read_triple=False, only_train=True)
path = os.path.join(data_path, part_name)
    with open(os.path.join(path, 'partition_book.txt')) as f:
        n_entities = len(f.readlines())
local_to_global = []
with open(os.path.join(path, 'local_to_global.txt')) as f:
for line in f:
local_to_global.append(int(line))
global_to_local = [0] * n_entities
for i in range(len(local_to_global)):
global_id = local_to_global[i]
global_to_local[global_id] = i
local_to_global = None
return global_to_local, dataset
##################################################################################
# This script runs the ComplEx model on the Freebase dataset in a distributed setting.
# You can change the hyper-parameters in this file but DO NOT run this script manually.
##################################################################################
machine_id=$1
server_count=$2
# Delete the temp file
rm *-shape
##################################################################################
# Start kvserver
##################################################################################
SERVER_ID_LOW=$((machine_id*server_count))
SERVER_ID_HIGH=$(((machine_id+1)*server_count))
while [ $SERVER_ID_LOW -lt $SERVER_ID_HIGH ]
do
MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 DGLBACKEND=pytorch python3 ../kvserver.py --model ComplEx --dataset Freebase \
--hidden_dim 400 --gamma 143.0 --lr 0.1 --total_client 160 --server_id $SERVER_ID_LOW &
let SERVER_ID_LOW+=1
done
##################################################################################
# Start kvclient
##################################################################################
MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 DGLBACKEND=pytorch python3 ../kvclient.py --model ComplEx --dataset Freebase \
--batch_size 1024 --neg_sample_size 256 --hidden_dim 400 --gamma 143.0 --lr 0.1 --max_step 12500 --log_interval 100 \
--batch_size_eval 1000 --neg_sample_size_test 1000 --test -adv --total_machine 4 --num_client 40
\ No newline at end of file
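# Example arguments (an assumption based on the launch script in this commit, which
# passes "<machine_id> <server_count>"): with the sample ip_config.txt (4 machines,
# 8 servers each, 40 clients each, hence --total_client 160), machine 0 would run
# this script as: <this_script> 0 8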
##################################################################################
# This script runs the DistMult model on the Freebase dataset in a distributed setting.
# You can change the hyper-parameters in this file but DO NOT run this script manually.
##################################################################################
machine_id=$1
server_count=$2
# Delete the temp file
rm *-shape
##################################################################################
# Start kvserver
##################################################################################
SERVER_ID_LOW=$((machine_id*server_count))
SERVER_ID_HIGH=$(((machine_id+1)*server_count))
while [ $SERVER_ID_LOW -lt $SERVER_ID_HIGH ]
do
MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 DGLBACKEND=pytorch python3 ../kvserver.py --model DistMult --dataset Freebase \
--hidden_dim 400 --gamma 143.0 --lr 0.08 --total_client 160 --server_id $SERVER_ID_LOW &
let SERVER_ID_LOW+=1
done
##################################################################################
# Start kvclient
##################################################################################
MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 DGLBACKEND=pytorch python3 ../kvclient.py --model DistMult --dataset Freebase \
--batch_size 1024 --neg_sample_size 256 --hidden_dim 400 --gamma 143.0 --lr 0.08 --max_step 12500 --log_interval 100 \
--batch_size_eval 1000 --neg_sample_size_test 1000 --test -adv --total_machine 4 --num_client 40
\ No newline at end of file
##################################################################################
# This script runs the TransE_l2 model on the Freebase dataset in a distributed setting.
# You can change the hyper-parameters in this file but DO NOT run this script manually.
##################################################################################
machine_id=$1
server_count=$2
# Delete the temp file
rm *-shape
##################################################################################
# Start kvserver
##################################################################################
SERVER_ID_LOW=$((machine_id*server_count))
SERVER_ID_HIGH=$(((machine_id+1)*server_count))
while [ $SERVER_ID_LOW -lt $SERVER_ID_HIGH ]
do
MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 DGLBACKEND=pytorch python3 ../kvserver.py --model TransE_l2 --dataset Freebase \
--hidden_dim 400 --gamma 10 --lr 0.1 --total_client 160 --server_id $SERVER_ID_LOW &
let SERVER_ID_LOW+=1
done
##################################################################################
# Start kvclient
##################################################################################
MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 DGLBACKEND=pytorch python3 ../kvclient.py --model TransE_l2 --dataset Freebase \
--batch_size 1000 --neg_sample_size 200 --hidden_dim 400 --gamma 10 --lr 0.1 --max_step 12500 --log_interval 100 \
--batch_size_eval 1000 --neg_sample_size_test 1000 --test -adv --regularization_coef 1e-9 --total_machine 4 --num_client 40
\ No newline at end of file
127.0.0.1 30050 8
127.0.0.1 30050 8
127.0.0.1 30050 8
127.0.0.1 30050 8
\ No newline at end of file
##################################################################################
# The user runs this script to launch distributed jobs on the cluster
##################################################################################
script_path=~/dgl/apps/kg/distributed
script_file=./freebase_transe_l2.sh
user_name=ubuntu
ssh_key=~/mctt.pem
server_count=$(awk 'NR==1 {print $3}' ip_config.txt)
# run command on remote machine
LINE_LOW=2
LINE_HIGH=$(awk 'END{print NR}' ip_config.txt)
let LINE_HIGH+=1
s_id=0
while [ $LINE_LOW -lt $LINE_HIGH ]
do
ip=$(awk 'NR=='$LINE_LOW' {print $1}' ip_config.txt)
let LINE_LOW+=1
let s_id+=1
ssh -i $ssh_key $user_name@$ip 'cd '$script_path'; '$script_file' '$s_id' '$server_count' ' &
done
# run command on local machine
$script_file 0 $server_count
\ No newline at end of file
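# Notes (based on the files in this commit): ip_config.txt lists one machine per line
# as "[ip] [base_port] [server_count]"; the first line also supplies server_count for
# every machine. The loop above ssh-es into the machines on lines 2..N of that file and
# runs the training script there as machines 1..N-1, then the last line runs it locally
# as machine 0.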
@@ -66,6 +66,8 @@ class ArgParser(argparse.ArgumentParser):
                           help='number of workers used for loading data')
         self.add_argument('--num_proc', type=int, default=1,
                           help='number of process used')
+        self.add_argument('--num_thread', type=int, default=1,
+                          help='number of thread used')
     def parse_args(self):
         args = super().parse_args()
...
import os
import argparse
import time
import logging
import socket
if os.name != 'nt':
import fcntl
import struct
import torch.multiprocessing as mp
from train_pytorch import load_model, dist_train_test
from train import get_logger
from dataloader import TrainDataset, NewBidirectionalOneShotIterator
from dataloader import get_dataset, get_partition_dataset
import dgl
import dgl.backend as F
NUM_THREAD = 1 # Fix the number of threads to 1 on kvclient
NUM_WORKER = 1 # Fix the number of worker for sampler to 1
class ArgParser(argparse.ArgumentParser):
def __init__(self):
super(ArgParser, self).__init__()
self.add_argument('--model_name', default='TransE',
choices=['TransE', 'TransE_l1', 'TransE_l2', 'TransR',
'RESCAL', 'DistMult', 'ComplEx', 'RotatE'],
help='model to use')
self.add_argument('--data_path', type=str, default='../data',
help='root path of all dataset')
self.add_argument('--dataset', type=str, default='FB15k',
help='dataset name, under data_path')
self.add_argument('--format', type=str, default='1',
help='the format of the dataset.')
self.add_argument('--save_path', type=str, default='ckpts',
help='place to save models and logs')
self.add_argument('--save_emb', type=str, default=None,
help='save the embeddings in the specific location.')
self.add_argument('--max_step', type=int, default=80000,
help='train xx steps')
self.add_argument('--warm_up_step', type=int, default=None,
help='for learning rate decay')
self.add_argument('--batch_size', type=int, default=1024,
help='batch size')
self.add_argument('--batch_size_eval', type=int, default=8,
help='batch size used for eval and test')
self.add_argument('--neg_sample_size', type=int, default=128,
help='negative sampling size')
self.add_argument('--neg_chunk_size', type=int, default=-1,
help='chunk size of the negative edges.')
self.add_argument('--neg_deg_sample', action='store_true',
help='negative sample proportional to vertex degree in the training')
self.add_argument('--neg_deg_sample_eval', action='store_true',
help='negative sampling proportional to vertex degree in the evaluation')
self.add_argument('--neg_sample_size_valid', type=int, default=1000,
help='negative sampling size for validation')
self.add_argument('--neg_chunk_size_valid', type=int, default=-1,
help='chunk size of the negative edges.')
self.add_argument('--neg_sample_size_test', type=int, default=-1,
help='negative sampling size for testing')
self.add_argument('--neg_chunk_size_test', type=int, default=-1,
help='chunk size of the negative edges.')
self.add_argument('--hidden_dim', type=int, default=256,
help='hidden dim used by relation and entity')
self.add_argument('--lr', type=float, default=0.0001,
help='learning rate')
self.add_argument('-g', '--gamma', type=float, default=12.0,
help='margin value')
self.add_argument('--eval_percent', type=float, default=1,
help='sample some percentage for evaluation.')
self.add_argument('--no_eval_filter', action='store_true',
help='do not filter positive edges among negative edges for evaluation')
self.add_argument('--gpu', type=int, default=[-1], nargs='+',
help='a list of active gpu ids, e.g. 0 1 2 4')
self.add_argument('--mix_cpu_gpu', action='store_true',
help='mix CPU and GPU training')
self.add_argument('-de', '--double_ent', action='store_true',
                          help='double entity dim for complex number')
self.add_argument('-dr', '--double_rel', action='store_true',
help='double relation dim for complex number')
self.add_argument('--seed', type=int, default=0,
                          help='set random seed for reproducibility')
        self.add_argument('-log', '--log_interval', type=int, default=1000,
                          help='print training log after every x steps')
self.add_argument('--eval_interval', type=int, default=10000,
help='do evaluation after every x steps')
self.add_argument('-adv', '--neg_adversarial_sampling', action='store_true',
help='if use negative adversarial sampling')
self.add_argument('-a', '--adversarial_temperature', default=1.0, type=float)
self.add_argument('--valid', action='store_true',
help='if valid a model')
self.add_argument('--test', action='store_true',
help='if test a model')
self.add_argument('-rc', '--regularization_coef', type=float, default=0.000002,
help='set value > 0.0 if regularization is used')
self.add_argument('-rn', '--regularization_norm', type=int, default=3,
help='norm used in regularization')
self.add_argument('--num_worker', type=int, default=32,
help='number of workers used for loading data')
        self.add_argument('--non_uni_weight', action='store_true',
                          help='if use non-uniform weight when computing loss')
self.add_argument('--init_step', type=int, default=0,
help='DONT SET MANUALLY, used for resume')
self.add_argument('--step', type=int, default=0,
help='DONT SET MANUALLY, track current step')
self.add_argument('--pickle_graph', action='store_true',
help='pickle built graph, building a huge graph is slow.')
self.add_argument('--num_proc', type=int, default=1,
help='number of process used')
self.add_argument('--rel_part', action='store_true',
help='enable relation partitioning')
self.add_argument('--soft_rel_part', action='store_true',
help='enable soft relation partition')
self.add_argument('--nomp_thread_per_process', type=int, default=-1,
help='num of omp threads used per process in multi-process training')
self.add_argument('--async_update', action='store_true',
help='allow async_update on node embedding')
self.add_argument('--force_sync_interval', type=int, default=-1,
help='We force a synchronization between processes every x steps')
self.add_argument('--strict_rel_part', action='store_true',
help='Strict relation partition')
self.add_argument('--machine_id', type=int, default=0,
help='Unique ID of current machine.')
self.add_argument('--total_machine', type=int, default=1,
                          help='Total number of machines.')
self.add_argument('--ip_config', type=str, default='ip_config.txt',
help='IP configuration file of kvstore')
self.add_argument('--num_client', type=int, default=1,
                          help='Number of clients on each machine.')
def get_long_tail_partition(n_relations, n_machine):
"""Relation types has a long tail distribution for many dataset.
So we need to average shuffle the data before we partition it.
"""
assert n_relations > 0, 'n_relations must be a positive number.'
assert n_machine > 0, 'n_machine must be a positive number.'
partition_book = [0] * n_relations
part_id = 0
for i in range(n_relations):
partition_book[i] = part_id
part_id += 1
if part_id == n_machine:
part_id = 0
return partition_book
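# Example: get_long_tail_partition(n_relations=5, n_machine=2) returns [0, 1, 0, 1, 0],
# i.e. relation types are assigned to machines round-robin.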
def local_ip4_addr_list():
"""Return a set of IPv4 address
"""
nic = set()
for ix in socket.if_nameindex():
name = ix[1]
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
ip = socket.inet_ntoa(fcntl.ioctl(
s.fileno(),
0x8915, # SIOCGIFADDR
struct.pack('256s', name[:15].encode("UTF-8")))[20:24])
nic.add(ip)
return nic
def get_local_machine_id(server_namebook):
"""Get machine ID via server_namebook
"""
assert len(server_namebook) > 0, 'server_namebook cannot be empty.'
res = 0
for ID, data in server_namebook.items():
machine_id = data[0]
ip = data[1]
if ip in local_ip4_addr_list():
res = machine_id
break
return res
def start_worker(args, logger):
"""Start kvclient for training
"""
train_time_start = time.time()
server_namebook = dgl.contrib.read_ip_config(filename=args.ip_config)
args.machine_id = get_local_machine_id(server_namebook)
dataset, entity_partition_book, local2global = get_partition_dataset(
args.data_path,
args.dataset,
args.format,
args.machine_id)
n_entities = dataset.n_entities
n_relations = dataset.n_relations
print('Partition %d n_entities: %d' % (args.machine_id, n_entities))
print("Partition %d n_relations: %d" % (args.machine_id, n_relations))
entity_partition_book = F.tensor(entity_partition_book)
relation_partition_book = get_long_tail_partition(dataset.n_relations, args.total_machine)
relation_partition_book = F.tensor(relation_partition_book)
local2global = F.tensor(local2global)
relation_partition_book.share_memory_()
entity_partition_book.share_memory_()
local2global.share_memory_()
model = load_model(logger, args, n_entities, n_relations)
model.share_memory()
# When we generate a batch of negative edges from a set of positive edges,
# we first divide the positive edges into chunks and corrupt the edges in a chunk
# together. By default, the chunk size is equal to the negative sample size.
# Usually, this works well. But we also allow users to specify the chunk size themselves.
if args.neg_chunk_size < 0:
args.neg_chunk_size = args.neg_sample_size
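    # For example, with the Freebase scripts in this commit (batch_size=1024,
    # neg_sample_size=256), each positive batch is split into 1024/256 = 4 chunks and,
    # roughly, the edges in a chunk are corrupted with one shared set of 256 negatives.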
num_workers = NUM_WORKER
train_data = TrainDataset(dataset, args, ranks=args.num_client)
train_samplers = []
for i in range(args.num_client):
train_sampler_head = train_data.create_sampler(args.batch_size,
args.neg_sample_size,
args.neg_chunk_size,
mode='head',
num_workers=num_workers,
shuffle=True,
exclude_positive=False,
rank=i)
train_sampler_tail = train_data.create_sampler(args.batch_size,
args.neg_sample_size,
args.neg_chunk_size,
mode='tail',
num_workers=num_workers,
shuffle=True,
exclude_positive=False,
rank=i)
train_samplers.append(NewBidirectionalOneShotIterator(train_sampler_head, train_sampler_tail,
args.neg_chunk_size, args.neg_sample_size,
True, n_entities))
dataset = None
print('Total data loading time {:.3f} seconds'.format(time.time() - train_time_start))
rel_parts = train_data.rel_parts if args.strict_rel_part or args.soft_rel_part else None
cross_rels = train_data.cross_rels if args.soft_rel_part else None
args.num_thread = NUM_THREAD
procs = []
barrier = mp.Barrier(args.num_client)
for i in range(args.num_client):
proc = mp.Process(target=dist_train_test, args=(args,
model,
train_samplers[i],
entity_partition_book,
relation_partition_book,
local2global,
i,
rel_parts,
cross_rels,
barrier))
procs.append(proc)
proc.start()
for proc in procs:
proc.join()
if __name__ == '__main__':
args = ArgParser().parse_args()
logger = get_logger(args)
start_worker(args, logger)
\ No newline at end of file
import os
import argparse
import time
import dgl
from dgl.contrib import KVServer
import torch as th
from train_pytorch import load_model
from dataloader import get_server_partition_dataset
NUM_THREAD = 1 # Fix the number of threads to 1 on kvstore
class KGEServer(KVServer):
"""User-defined kvstore for DGL-KGE
"""
def _push_handler(self, name, ID, data, target):
"""Row-Sparse Adagrad updater
"""
original_name = name[0:-6]
state_sum = target[original_name+'_state-data-']
grad_sum = (data * data).mean(1)
state_sum.index_add_(0, ID, grad_sum)
std = state_sum[ID] # _sparse_mask
std_values = std.sqrt_().add_(1e-10).unsqueeze(1)
tmp = (-self.clr * data / std_values)
target[name].index_add_(0, ID, tmp)
def set_clr(self, learning_rate):
"""Set learning rate for Row-Sparse Adagrad updater
"""
self.clr = learning_rate
# Note: Most of the args are unnecessary for KVStore, will remove them later
class ArgParser(argparse.ArgumentParser):
def __init__(self):
super(ArgParser, self).__init__()
self.add_argument('--model_name', default='TransE',
choices=['TransE', 'TransE_l1', 'TransE_l2', 'TransR',
'RESCAL', 'DistMult', 'ComplEx', 'RotatE'],
help='model to use')
self.add_argument('--data_path', type=str, default='../data',
help='root path of all dataset')
self.add_argument('--dataset', type=str, default='FB15k',
help='dataset name, under data_path')
self.add_argument('--format', type=str, default='1',
help='the format of the dataset.')
self.add_argument('--hidden_dim', type=int, default=256,
help='hidden dim used by relation and entity')
self.add_argument('--lr', type=float, default=0.0001,
help='learning rate')
self.add_argument('-g', '--gamma', type=float, default=12.0,
help='margin value')
self.add_argument('--gpu', type=int, default=[-1], nargs='+',
help='a list of active gpu ids, e.g. 0')
self.add_argument('--mix_cpu_gpu', action='store_true',
help='mix CPU and GPU training')
self.add_argument('-de', '--double_ent', action='store_true',
                          help='double entity dim for complex number')
self.add_argument('-dr', '--double_rel', action='store_true',
help='double relation dim for complex number')
self.add_argument('--seed', type=int, default=0,
help='set random seed for reproducibility')
self.add_argument('--rel_part', action='store_true',
help='enable relation partitioning')
self.add_argument('--soft_rel_part', action='store_true',
help='enable soft relation partition')
self.add_argument('--nomp_thread_per_process', type=int, default=-1,
help='num of omp threads used per process in multi-process training')
self.add_argument('--async_update', action='store_true',
help='allow async_update on node embedding')
self.add_argument('--strict_rel_part', action='store_true',
help='Strict relation partition')
self.add_argument('--server_id', type=int, default=0,
help='Unique ID of each server')
self.add_argument('--ip_config', type=str, default='ip_config.txt',
help='IP configuration file of kvstore')
self.add_argument('--total_client', type=int, default=1,
help='Total number of client worker nodes')
def get_server_data(args, machine_id):
"""Get data from data_path/dataset/part_machine_id
    Return: global2local,
entity_emb,
entity_state,
relation_emb,
relation_emb_state
"""
g2l, dataset = get_server_partition_dataset(
args.data_path,
args.dataset,
args.format,
machine_id)
    # Note that the dataset doesn't contain the triples
print('n_entities: ' + str(dataset.n_entities))
print('n_relations: ' + str(dataset.n_relations))
model = load_model(None, args, dataset.n_entities, dataset.n_relations)
return g2l, model.entity_emb.emb, model.entity_emb.state_sum, model.relation_emb.emb, model.relation_emb.state_sum
def start_server(args):
"""Start kvstore service
"""
th.set_num_threads(NUM_THREAD)
server_namebook = dgl.contrib.read_ip_config(filename=args.ip_config)
my_server = KGEServer(server_id=args.server_id,
server_namebook=server_namebook,
num_client=args.total_client)
my_server.set_clr(args.lr)
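    # Only the first ("master") kvserver in each machine's server group loads the real
    # embedding tensors; the remaining ("backup") servers on the same machine register
    # the same names without data and presumably attach to the master's shared tensors.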
if my_server.get_id() % my_server.get_group_count() == 0: # master server
g2l, entity_emb, entity_emb_state, relation_emb, relation_emb_state = get_server_data(args, my_server.get_machine_id())
my_server.set_global2local(name='entity_emb', global2local=g2l)
my_server.init_data(name='relation_emb', data_tensor=relation_emb)
my_server.init_data(name='relation_emb_state', data_tensor=relation_emb_state)
my_server.init_data(name='entity_emb', data_tensor=entity_emb)
my_server.init_data(name='entity_emb_state', data_tensor=entity_emb_state)
else: # backup server
my_server.set_global2local(name='entity_emb')
my_server.init_data(name='relation_emb')
my_server.init_data(name='relation_emb_state')
my_server.init_data(name='entity_emb')
my_server.init_data(name='entity_emb_state')
print('KVServer %d listen for requests ...' % my_server.get_id())
my_server.start()
if __name__ == '__main__':
args = ArgParser().parse_args()
start_server(args)
\ No newline at end of file
@@ -462,3 +462,36 @@ class KEModel(object):
         """Terminate the async update for entity embedding.
         """
         self.entity_emb.finish_async_update()
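    # Note: entity IDs inside pos_g/neg_g are local to this machine's partition; l2g
    # (local_to_global) maps them to the global IDs used by the kvstore before every
    # pull/push, while relation IDs are already global and are used as-is.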
def pull_model(self, client, pos_g, neg_g):
with th.no_grad():
entity_id = F.cat(seq=[pos_g.ndata['id'], neg_g.ndata['id']], dim=0)
relation_id = pos_g.edata['id']
entity_id = F.tensor(np.unique(F.asnumpy(entity_id)))
relation_id = F.tensor(np.unique(F.asnumpy(relation_id)))
l2g = client.get_local2global()
global_entity_id = l2g[entity_id]
entity_data = client.pull(name='entity_emb', id_tensor=global_entity_id)
relation_data = client.pull(name='relation_emb', id_tensor=relation_id)
self.entity_emb.emb[entity_id] = entity_data
self.relation_emb.emb[relation_id] = relation_data
def push_gradient(self, client):
with th.no_grad():
l2g = client.get_local2global()
for entity_id, entity_data in self.entity_emb.trace:
grad = entity_data.grad.data
                global_entity_id = l2g[entity_id]
client.push(name='entity_emb', id_tensor=global_entity_id, data_tensor=grad)
for relation_id, relation_data in self.relation_emb.trace:
grad = relation_data.grad.data
client.push(name='relation_emb', id_tensor=relation_id, data_tensor=grad)
self.entity_emb.trace = []
self.relation_emb.trace = []
\ No newline at end of file
@@ -110,6 +110,10 @@ class ArgParser(argparse.ArgumentParser):
                           help='pickle built graph, building a huge graph is slow.')
         self.add_argument('--num_proc', type=int, default=1,
                           help='number of process used')
+        self.add_argument('--num_test_proc', type=int, default=1,
+                          help='number of process used for test')
+        self.add_argument('--num_thread', type=int, default=1,
+                          help='number of thread used')
         self.add_argument('--rel_part', action='store_true',
                           help='enable relation partitioning')
         self.add_argument('--soft_rel_part', action='store_true',
@@ -149,6 +153,7 @@ def get_logger(args):
 def run(args, logger):
+    train_time_start = time.time()
     # load dataset and samplers
     dataset = get_dataset(args.data_path, args.dataset, args.format)
     n_entities = dataset.n_entities
@@ -185,7 +190,7 @@ def run(args, logger):
             args.num_thread = 4
         else:
             # CPU training
-            args.num_thread = mp.cpu_count() // args.num_proc + 1
+            args.num_thread = 1
     else:
         args.num_thread = args.nomp_thread_per_process
@@ -235,7 +240,10 @@ def run(args, logger):
     if args.num_proc > 1:
         num_workers = 1
     if args.valid or args.test:
-        args.num_test_proc = args.num_proc if args.num_proc < len(args.gpu) else len(args.gpu)
+        if len(args.gpu) > 1:
+            args.num_test_proc = args.num_proc if args.num_proc < len(args.gpu) else len(args.gpu)
+        else:
+            args.num_test_proc = args.num_proc
         eval_dataset = EvalDataset(dataset, args)
         if args.valid:
             # Here we want to use the regualr negative sampler because we need to ensure that
@@ -324,6 +332,8 @@ def run(args, logger):
     if args.num_proc > 1 or args.async_update:
         model.share_memory()
+    print('Total data loading time {:.3f} seconds'.format(time.time() - train_time_start))
     # train
     start = time.time()
     rel_parts = train_data.rel_parts if args.strict_rel_part or args.soft_rel_part else None
...
 from models import KEModel
+import torch.multiprocessing as mp
 from torch.utils.data import DataLoader
 import torch.optim as optim
 import torch as th
@@ -15,6 +16,66 @@ import logging
 import time
 from functools import wraps
import dgl
from dgl.contrib import KVClient
import dgl.backend as F
from dataloader import EvalDataset
from dataloader import get_dataset
class KGEClient(KVClient):
"""User-defined kvclient for DGL-KGE
"""
def _push_handler(self, name, ID, data, target):
"""Row-Sparse Adagrad updater
"""
original_name = name[0:-6]
state_sum = target[original_name+'_state-data-']
grad_sum = (data * data).mean(1)
state_sum.index_add_(0, ID, grad_sum)
std = state_sum[ID] # _sparse_mask
std_values = std.sqrt_().add_(1e-10).unsqueeze(1)
tmp = (-self.clr * data / std_values)
target[name].index_add_(0, ID, tmp)
def set_clr(self, learning_rate):
"""Set learning rate
"""
self.clr = learning_rate
def set_local2global(self, l2g):
self._l2g = l2g
def get_local2global(self):
return self._l2g
def connect_to_kvstore(args, entity_pb, relation_pb, l2g):
"""Create kvclient and connect to kvstore service
"""
server_namebook = dgl.contrib.read_ip_config(filename=args.ip_config)
my_client = KGEClient(server_namebook=server_namebook)
my_client.set_clr(args.lr)
my_client.connect()
if my_client.get_id() % args.num_client == 0:
my_client.set_partition_book(name='entity_emb', partition_book=entity_pb)
my_client.set_partition_book(name='relation_emb', partition_book=relation_pb)
else:
my_client.set_partition_book(name='entity_emb')
my_client.set_partition_book(name='relation_emb')
my_client.set_local2global(l2g)
return my_client
 def load_model(logger, args, n_entities, n_relations, ckpt=None):
     model = KEModel(args, args.model_name, n_entities, n_relations,
                     args.hidden_dim, args.gamma,
@@ -29,7 +90,7 @@ def load_model_from_checkpoint(logger, args, n_entities, n_relations, ckpt_path)
     model.load_emb(ckpt_path, args.dataset)
     return model

-def train(args, model, train_sampler, valid_samplers=None, rank=0, rel_parts=None, cross_rels=None, barrier=None):
+def train(args, model, train_sampler, valid_samplers=None, rank=0, rel_parts=None, cross_rels=None, barrier=None, client=None):
     logs = []
     for arg in vars(args):
         logging.info('{:20}:{}'.format(arg, getattr(args, arg)))
@@ -46,7 +107,7 @@ def train(args, model, train_sampler, valid_samplers=None, rank=0, rel_parts=Non
     if args.soft_rel_part:
         model.prepare_cross_rels(cross_rels)
-    start = time.time()
+    train_start = start = time.time()
     sample_time = 0
     update_time = 0
     forward_time = 0
@@ -57,6 +118,9 @@ def train(args, model, train_sampler, valid_samplers=None, rank=0, rel_parts=Non
         sample_time += time.time() - start1
         args.step = step
+        if client is not None:
+            model.pull_model(client, pos_g, neg_g)
         start1 = time.time()
         loss, log = model.forward(pos_g, neg_g, gpu_id)
         forward_time += time.time() - start1
@@ -66,7 +130,10 @@ def train(args, model, train_sampler, valid_samplers=None, rank=0, rel_parts=Non
         backward_time += time.time() - start1
         start1 = time.time()
-        model.update(gpu_id)
+        if client is not None:
+            model.push_gradient(client)
+        else:
+            model.update(gpu_id)
         update_time += time.time() - start1
         logs.append(log)
@@ -104,7 +171,7 @@ def train(args, model, train_sampler, valid_samplers=None, rank=0, rel_parts=Non
     if barrier is not None:
         barrier.wait()
-    print('train {} takes {:.3f} seconds'.format(rank, time.time() - start))
+    print('train {} takes {:.3f} seconds'.format(rank, time.time() - train_start))
     if args.async_update:
         model.finish_async_update()
     if args.strict_rel_part or args.soft_rel_part:
@@ -147,3 +214,129 @@ def train_mp(args, model, train_sampler, valid_samplers=None, rank=0, rel_parts=
 @thread_wrapped_func
 def test_mp(args, model, test_samplers, rank=0, mode='Test', queue=None):
     test(args, model, test_samplers, rank, mode, queue)
@thread_wrapped_func
def dist_train_test(args, model, train_sampler, entity_pb, relation_pb, l2g, rank=0, rel_parts=None, cross_rels=None, barrier=None):
if args.num_proc > 1:
th.set_num_threads(args.num_thread)
client = connect_to_kvstore(args, entity_pb, relation_pb, l2g)
client.barrier()
train_time_start = time.time()
train(args, model, train_sampler, None, rank, rel_parts, cross_rels, barrier, client)
client.barrier()
print('Total train time {:.3f} seconds'.format(time.time() - train_time_start))
model = None
if client.get_id() % args.num_client == 0: # pull full model from kvstore
args.num_test_proc = args.num_client
dataset_full = get_dataset(args.data_path, args.dataset, args.format)
print('Full data n_entities: ' + str(dataset_full.n_entities))
print("Full data n_relations: " + str(dataset_full.n_relations))
model_test = load_model(None, args, dataset_full.n_entities, dataset_full.n_relations)
eval_dataset = EvalDataset(dataset_full, args)
if args.test:
model_test.share_memory()
if args.neg_sample_size_test < 0:
args.neg_sample_size_test = dataset_full.n_entities
args.eval_filter = not args.no_eval_filter
if args.neg_deg_sample_eval:
assert not args.eval_filter, "if negative sampling based on degree, we can't filter positive edges."
if args.neg_chunk_size_valid < 0:
args.neg_chunk_size_valid = args.neg_sample_size_valid
if args.neg_chunk_size_test < 0:
args.neg_chunk_size_test = args.neg_sample_size_test
print("Pull relation_emb ...")
relation_id = F.arange(0, model_test.n_relations)
relation_data = client.pull(name='relation_emb', id_tensor=relation_id)
model_test.relation_emb.emb[relation_id] = relation_data
print("Pull entity_emb ... ")
# split model into 100 small parts
start = 0
percent = 0
entity_id = F.arange(0, model_test.n_entities)
count = int(model_test.n_entities / 100)
end = start + count
        while True:
            print("Pull %d / 100 ..." % percent)
            if end >= model_test.n_entities:
                # clamp to n_entities so the final slice keeps the last entity row
                end = model_test.n_entities
            tmp_id = entity_id[start:end]
            entity_data = client.pull(name='entity_emb', id_tensor=tmp_id)
            model_test.entity_emb.emb[tmp_id] = entity_data
            if end == model_test.n_entities:
                break
            start = end
            end += count
percent += 1
if args.save_emb is not None:
if not os.path.exists(args.save_emb):
os.mkdir(args.save_emb)
model_test.save_emb(args.save_emb, args.dataset)
if args.test:
args.num_thread = 1
test_sampler_tails = []
test_sampler_heads = []
for i in range(args.num_test_proc):
test_sampler_head = eval_dataset.create_sampler('test', args.batch_size_eval,
args.neg_sample_size_test,
args.neg_chunk_size_test,
args.eval_filter,
mode='chunk-head',
num_workers=args.num_thread,
rank=i, ranks=args.num_test_proc)
test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval,
args.neg_sample_size_test,
args.neg_chunk_size_test,
args.eval_filter,
mode='chunk-tail',
num_workers=args.num_thread,
rank=i, ranks=args.num_test_proc)
test_sampler_heads.append(test_sampler_head)
test_sampler_tails.append(test_sampler_tail)
eval_dataset = None
dataset_full = None
print("Run test, test processes: %d" % args.num_test_proc)
queue = mp.Queue(args.num_test_proc)
procs = []
for i in range(args.num_test_proc):
proc = mp.Process(target=test_mp, args=(args,
model_test,
[test_sampler_heads[i], test_sampler_tails[i]],
i,
'Test',
queue))
procs.append(proc)
proc.start()
total_metrics = {}
metrics = {}
logs = []
for i in range(args.num_test_proc):
log = queue.get()
logs = logs + log
for metric in logs[0].keys():
metrics[metric] = sum([log[metric] for log in logs]) / len(logs)
for k, v in metrics.items():
print('Test average {} at [{}/{}]: {}'.format(k, args.step, args.max_step, v))
for proc in procs:
proc.join()
if client.get_id() == 0:
client.shut_down()
\ No newline at end of file
@@ -72,7 +72,7 @@ def read_ip_config(filename):
             server_id += 1
         machine_id += 1
     except:
-        print("Error: data format on each line should be: [machine_id] [ip] [base_port] [server_count]")
+        print("Error: data format on each line should be: [ip] [base_port] [server_count]")
     return server_namebook
...
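For reference, a minimal sketch (an assumption, not the actual dgl.contrib code) of how lines in the new '[ip] [base_port] [server_count]' format could expand into the server_namebook entries consumed by kvclient.py, where each entry is read as [machine_id, ip, port, ...] (cf. get_local_machine_id above):

def parse_ip_config_sketch(filename):
    """Hypothetical parser mirroring the expected ip_config.txt format; the real
    implementation lives in dgl.contrib.read_ip_config."""
    server_namebook = {}
    server_id = 0
    with open(filename) as f:
        for machine_id, line in enumerate(f):
            if not line.strip():
                continue
            ip, base_port, server_count = line.split()
            for offset in range(int(server_count)):
                # one namebook entry per kvserver process on this machine;
                # a distinct port per server process is an assumption here
                server_namebook[server_id] = [machine_id, ip, int(base_port) + offset, int(server_count)]
                server_id += 1
    return server_namebook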