"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "bfdd1eaa446bd58ec35dbb54e247abed11c70084"
Unverified Commit f8ae6350 authored by Chao Ma, committed by GitHub

[DGL-KE] Add METIS preprocessing pipeline (#1365)

* update metis

* update

* update dataloader

* update dataloader

* new script

* update

* update

* update

* update

* update

* update

* update

* update dataloader

* update

* update

* update

* update

* update
parent 2ce0e21b
@@ -44,6 +44,12 @@ def _parse_srd_format(format):
     if format == "trh":
         return [2, 1, 0]

+def _file_line(path):
+    with open(path) as f:
+        for i, l in enumerate(f):
+            pass
+    return i + 1
+
 class KGDataset:
     '''Load a knowledge graph
@@ -58,18 +64,16 @@ class KGDataset:
     The triples are stored as 'head_name\trelation_name\ttail_name'.
     '''
-    def __init__(self, entity_path, relation_path,
-                 train_path, valid_path=None, test_path=None,
-                 format=[0,1,2], read_triple=True, only_train=False,
-                 skip_first_line=False):
+    def __init__(self, entity_path, relation_path, train_path,
+                 valid_path=None, test_path=None, format=[0,1,2], skip_first_line=False):
         self.entity2id, self.n_entities = self.read_entity(entity_path)
         self.relation2id, self.n_relations = self.read_relation(relation_path)
-        if read_triple == True:
-            self.train = self.read_triple(train_path, "train", skip_first_line, format)
-            if only_train == False:
-                self.valid = self.read_triple(valid_path, "valid", skip_first_line, format)
-                self.test = self.read_triple(test_path, "test", skip_first_line, format)
+        self.train = self.read_triple(train_path, "train", skip_first_line, format)
+        if valid_path is not None:
+            self.valid = self.read_triple(valid_path, "valid", skip_first_line, format)
+        if test_path is not None:
+            self.test = self.read_triple(test_path, "test", skip_first_line, format)

     def read_entity(self, entity_path):
         with open(entity_path) as f:
@@ -94,6 +98,7 @@ class KGDataset:
         if path is None:
             return None

+        print('Reading {} triples....'.format(mode))
         heads = []
         tails = []
         rels = []
@@ -106,12 +111,56 @@ class KGDataset:
                 heads.append(self.entity2id[h])
                 rels.append(self.relation2id[r])
                 tails.append(self.entity2id[t])
         heads = np.array(heads, dtype=np.int64)
         tails = np.array(tails, dtype=np.int64)
         rels = np.array(rels, dtype=np.int64)
+        print('Finished. Read {} {} triples.'.format(len(heads), mode))

         return (heads, rels, tails)

+class PartitionKGDataset():
+    '''Load a partitioned knowledge graph
+
+    The folder with a partitioned knowledge graph has four files:
+
+    * relations stores the mapping between relation Id and relation name.
+    * train stores the triples in the training set.
+    * local_to_global stores the mapping of local id and global id
+    * partition_book stores the machine id of each entity
+
+    The triples are stored as 'head_id\relation_id\tail_id'.
+    '''
+    def __init__(self, relation_path, train_path, local2global_path,
+                 read_triple=True, skip_first_line=False):
+        self.n_entities = _file_line(local2global_path)
+        if skip_first_line == False:
+            self.n_relations = _file_line(relation_path)
+        else:
+            self.n_relations = _file_line(relation_path) - 1
+        if read_triple == True:
+            self.train = self.read_triple(train_path, "train")
+
+    def read_triple(self, path, mode):
+        heads = []
+        tails = []
+        rels = []
+        print('Reading {} triples....'.format(mode))
+        with open(path) as f:
+            for line in f:
+                h, r, t = line.strip().split('\t')
+                heads.append(int(h))
+                rels.append(int(r))
+                tails.append(int(t))
+        heads = np.array(heads, dtype=np.int64)
+        tails = np.array(tails, dtype=np.int64)
+        rels = np.array(rels, dtype=np.int64)
+        print('Finished. Read {} {} triples.'.format(len(heads), mode))
+        return (heads, rels, tails)
+
 class KGDatasetFB15k(KGDataset):
     '''Load a knowledge graph FB15k
@@ -125,7 +174,7 @@ class KGDatasetFB15k(KGDataset):
     The mapping between entity (relation) name and entity (relation) Id is stored as 'name\tid'.
     The triples are stored as 'head_nid\trelation_id\ttail_nid'.
     '''
-    def __init__(self, path, name='FB15k', read_triple=True, only_train=False):
+    def __init__(self, path, name='FB15k'):
         self.name = name
         url = 'https://data.dgl.ai/dataset/{}.zip'.format(name)
@@ -138,9 +187,8 @@ class KGDatasetFB15k(KGDataset):
                          os.path.join(self.path, 'relations.dict'),
                          os.path.join(self.path, 'train.txt'),
                          os.path.join(self.path, 'valid.txt'),
-                         os.path.join(self.path, 'test.txt'),
-                         read_triple=read_triple,
-                         only_train=only_train)
+                         os.path.join(self.path, 'test.txt'))

 class KGDatasetFB15k237(KGDataset):
     '''Load a knowledge graph FB15k-237
@@ -155,7 +203,7 @@ class KGDatasetFB15k237(KGDataset):
     The mapping between entity (relation) name and entity (relation) Id is stored as 'name\tid'.
     The triples are stored as 'head_nid\trelation_id\ttail_nid'.
     '''
-    def __init__(self, path, name='FB15k-237', read_triple=True, only_train=False):
+    def __init__(self, path, name='FB15k-237'):
         self.name = name
         url = 'https://data.dgl.ai/dataset/{}.zip'.format(name)
@@ -168,9 +216,8 @@ class KGDatasetFB15k237(KGDataset):
                          os.path.join(self.path, 'relations.dict'),
                          os.path.join(self.path, 'train.txt'),
                          os.path.join(self.path, 'valid.txt'),
-                         os.path.join(self.path, 'test.txt'),
-                         read_triple=read_triple,
-                         only_train=only_train)
+                         os.path.join(self.path, 'test.txt'))

 class KGDatasetWN18(KGDataset):
     '''Load a knowledge graph wn18
@@ -185,7 +232,7 @@ class KGDatasetWN18(KGDataset):
     The mapping between entity (relation) name and entity (relation) Id is stored as 'name\tid'.
     The triples are stored as 'head_nid\trelation_id\ttail_nid'.
     '''
-    def __init__(self, path, name='wn18', read_triple=True, only_train=False):
+    def __init__(self, path, name='wn18'):
         self.name = name
         url = 'https://data.dgl.ai/dataset/{}.zip'.format(name)
@@ -198,9 +245,8 @@ class KGDatasetWN18(KGDataset):
                          os.path.join(self.path, 'relations.dict'),
                          os.path.join(self.path, 'train.txt'),
                          os.path.join(self.path, 'valid.txt'),
-                         os.path.join(self.path, 'test.txt'),
-                         read_triple=read_triple,
-                         only_train=only_train)
+                         os.path.join(self.path, 'test.txt'))

 class KGDatasetWN18rr(KGDataset):
     '''Load a knowledge graph wn18rr
@@ -215,7 +261,7 @@ class KGDatasetWN18rr(KGDataset):
     The mapping between entity (relation) name and entity (relation) Id is stored as 'name\tid'.
     The triples are stored as 'head_nid\trelation_id\ttail_nid'.
     '''
-    def __init__(self, path, name='wn18rr', read_triple=True, only_train=False):
+    def __init__(self, path, name='wn18rr'):
         self.name = name
         url = 'https://data.dgl.ai/dataset/{}.zip'.format(name)
@@ -228,9 +274,7 @@ class KGDatasetWN18rr(KGDataset):
                          os.path.join(self.path, 'relations.dict'),
                          os.path.join(self.path, 'train.txt'),
                          os.path.join(self.path, 'valid.txt'),
-                         os.path.join(self.path, 'test.txt'),
-                         read_triple=read_triple,
-                         only_train=only_train)
+                         os.path.join(self.path, 'test.txt'))

 class KGDatasetFreebase(KGDataset):
     '''Load a knowledge graph Full Freebase
@@ -245,7 +289,7 @@ class KGDatasetFreebase(KGDataset):
     The mapping between entity (relation) name and entity (relation) Id is stored as 'name\tid'.
     The triples are stored as 'head_nid\trelation_id\ttail_nid'.
     '''
-    def __init__(self, path, name='Freebase', read_triple=True, only_train=False):
+    def __init__(self, path, name='Freebase'):
         self.name = name
         url = 'https://data.dgl.ai/dataset/{}.zip'.format(name)
@@ -258,9 +302,7 @@ class KGDatasetFreebase(KGDataset):
                          os.path.join(self.path, 'relation2id.txt'),
                          os.path.join(self.path, 'train.txt'),
                          os.path.join(self.path, 'valid.txt'),
-                         os.path.join(self.path, 'test.txt'),
-                         read_triple=read_triple,
-                         only_train=only_train)
+                         os.path.join(self.path, 'test.txt'))

     def read_entity(self, entity_path):
         with open(entity_path) as f_ent:
@@ -285,6 +327,7 @@ class KGDatasetFreebase(KGDataset):
                 heads.append(int(h))
                 tails.append(int(t))
                 rels.append(int(r))
         heads = np.array(heads, dtype=np.int64)
         tails = np.array(tails, dtype=np.int64)
         rels = np.array(rels, dtype=np.int64)
@@ -319,9 +362,7 @@ class KGDatasetUDDRaw(KGDataset):
             super(KGDatasetUDDRaw, self).__init__("entities.tsv",
                                                   "relation.tsv",
                                                   os.path.join(path, files[0]),
-                                                  format=format,
-                                                  read_triple=True,
-                                                  only_train=True)
+                                                  format=format)
         # Train, validation and test set are provided
         if len(files) == 3:
             super(KGDatasetUDDRaw, self).__init__("entities.tsv",
@@ -329,9 +370,7 @@ class KGDatasetUDDRaw(KGDataset):
                                                   os.path.join(path, files[0]),
                                                   os.path.join(path, files[1]),
                                                   os.path.join(path, files[2]),
-                                                  format=format,
-                                                  read_triple=True,
-                                                  only_train=False)
+                                                  format=format)

     def load_entity_relation(self, path, files, format):
         entity_map = {}
@@ -376,7 +415,7 @@ class KGDatasetUDD(KGDataset):
     The mapping between entity (relation) name and entity (relation) Id is stored as 'name\tid'.
     The triples are stored as 'head_nid\trelation_id\ttail_nid'.
     '''
-    def __init__(self, path, name, files, format, read_triple=True, only_train=False):
+    def __init__(self, path, name, files, format):
         self.name = name
         for f in files:
             assert os.path.exists(os.path.join(path, f)), \
@@ -388,18 +427,14 @@ class KGDatasetUDD(KGDataset):
                                                os.path.join(path, files[1]),
                                                os.path.join(path, files[2]),
                                                None, None,
-                                               format=format,
-                                               read_triple=read_triple,
-                                               only_train=only_train)
+                                               format=format)
         if len(files) == 5:
             super(KGDatasetUDD, self).__init__(os.path.join(path, files[0]),
                                                os.path.join(path, files[1]),
                                                os.path.join(path, files[2]),
                                                os.path.join(path, files[3]),
                                                os.path.join(path, files[4]),
-                                               format=format,
-                                               read_triple=read_triple,
-                                               only_train=only_train)
+                                               format=format)

     def read_entity(self, entity_path):
         n_entities = 0
@@ -463,77 +498,72 @@ def get_dataset(data_path, data_name, format_str, files=None):
     return dataset

-def get_partition_dataset(data_path, data_name, format_str, part_id):
-    part_name = os.path.join(data_name, 'part_'+str(part_id))
-    if format_str == 'built_in':
-        if data_name == 'Freebase':
-            dataset = KGDatasetFreebase(data_path, part_name, read_triple=True, only_train=True)
-        elif data_name == 'FB15k':
-            dataset = KGDatasetFB15k(data_path, part_name, read_triple=True, only_train=True)
-        elif data_name == 'FB15k-237':
-            dataset = KGDatasetFB15k237(data_path, part_name, read_triple=True, only_train=True)
-        elif data_name == 'wn18':
-            dataset = KGDatasetWN18(data_path, part_name, read_triple=True, only_train=True)
-        elif data_name == 'wn18rr':
-            dataset = KGDatasetWN18rr(data_path, part_name, read_triple=True, only_train=True)
-        else:
-            assert False, "Unknown dataset {}".format(data_name)
-    elif format_str == 'raw_udd':
-        # user defined dataset
-        assert False, "When using partitioned dataset, we assume dataset will not be raw"
-    elif format_str == 'udd':
-        # user defined dataset
-        format = format_str[4:]
-        dataset = KGDatasetUDD(data_path, data_name, files, format, read_triple=True, only_train=True)
-    else:
-        assert False, "Unknown format {}".format(format_str)
-
-    path = os.path.join(data_path, part_name)
+def get_partition_dataset(data_path, data_name, part_id):
+    part_name = os.path.join(data_name, 'partition_'+str(part_id))
+    path = os.path.join(data_path, part_name)
+
+    if not os.path.exists(path):
+        print('Partition file not found.')
+        exit()
+
+    train_path = os.path.join(path, 'train.txt')
+    local2global_path = os.path.join(path, 'local_to_global.txt')
+    partition_book_path = os.path.join(path, 'partition_book.txt')
+
+    if data_name == 'Freebase':
+        relation_path = os.path.join(path, 'relation2id.txt')
+        skip_first_line = True
+    else:
+        relation_path = os.path.join(path, 'relations.dict')
+        skip_first_line = False
+
+    dataset = PartitionKGDataset(relation_path,
+                                 train_path,
+                                 local2global_path,
+                                 read_triple=True,
+                                 skip_first_line=skip_first_line)

     partition_book = []
-    with open(os.path.join(path, 'partition_book.txt')) as f:
+    with open(partition_book_path) as f:
         for line in f:
             partition_book.append(int(line))

     local_to_global = []
-    with open(os.path.join(path, 'local_to_global.txt')) as f:
+    with open(local2global_path) as f:
         for line in f:
             local_to_global.append(int(line))

     return dataset, partition_book, local_to_global

-def get_server_partition_dataset(data_path, data_name, format_str, part_id):
-    part_name = os.path.join(data_name, 'part_'+str(part_id))
-    if format_str == 'built_in':
-        if data_name == 'Freebase':
-            dataset = KGDatasetFreebase(data_path, part_name, read_triple=False, only_train=True)
-        elif data_name == 'FB15k':
-            dataset = KGDatasetFB15k(data_path, part_name, read_triple=False, only_train=True)
-        elif data_name == 'FB15k-237':
-            dataset = KGDatasetFB15k237(data_path, part_name, read_triple=False, only_train=True)
-        elif data_name == 'wn18':
-            dataset = KGDatasetWN18(data_path, part_name, read_triple=False, only_train=True)
-        elif data_name == 'wn18rr':
-            dataset = KGDatasetWN18rr(data_path, part_name, read_triple=False, only_train=True)
-        else:
-            assert False, "Unknown dataset {}".format(data_name)
-    elif format_str == 'raw_udd':
-        # user defined dataset
-        assert False, "When using partitioned dataset, we assume dataset will not be raw"
-    elif format_str == 'udd':
-        # user defined dataset
-        format = format_str[4:]
-        dataset = KGDatasetUDD(data_path, data_name, files, format, read_triple=False, only_train=True)
-    else:
-        assert False, "Unknown format {}".format(format_str)
-
-    path = os.path.join(data_path, part_name)
-
-    n_entities = len(open(os.path.join(path, 'partition_book.txt')).readlines())
+def get_server_partition_dataset(data_path, data_name, part_id):
+    part_name = os.path.join(data_name, 'partition_'+str(part_id))
+    path = os.path.join(data_path, part_name)
+
+    if not os.path.exists(path):
+        print('Partition file not found.')
+        exit()
+
+    train_path = os.path.join(path, 'train.txt')
+    local2global_path = os.path.join(path, 'local_to_global.txt')
+
+    if data_name == 'Freebase':
+        relation_path = os.path.join(path, 'relation2id.txt')
+        skip_first_line = True
+    else:
+        relation_path = os.path.join(path, 'relations.dict')
+        skip_first_line = False
+
+    dataset = PartitionKGDataset(relation_path,
+                                 train_path,
+                                 local2global_path,
+                                 read_triple=False,
+                                 skip_first_line=skip_first_line)
+
+    n_entities = _file_line(os.path.join(path, 'partition_book.txt'))
     local_to_global = []
-    with open(os.path.join(path, 'local_to_global.txt')) as f:
+    with open(local2global_path) as f:
         for line in f:
             local_to_global.append(int(line))
......
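For orientation, here is a minimal usage sketch (not part of the commit) of the new partition loader; the data path and dataset name are hypothetical, and the directory layout is the `partition_<id>` one written by `partition.py`:

```python
# Minimal sketch: load partition 0 of a METIS-partitioned FB15k (paths are assumptions).
from dataloader import get_partition_dataset

machine_id = 0  # id of this machine / partition
dataset, partition_book, local_to_global = get_partition_dataset(
    '../data', 'FB15k', machine_id)

# dataset.train is (heads, rels, tails) as int64 numpy arrays of *local* ids;
# local_to_global[local_id] gives the global entity id, and
# partition_book[global_id] gives the machine that owns that entity.
heads, rels, tails = dataset.train
print(dataset.n_entities, dataset.n_relations, len(heads))
```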
## Training Scripts for distributed training
1. Partition data
Partition FB15k:
```bash
./partition.sh ../data FB15k 4
```
Partition Freebase:
```bash
./partition.sh ../data Freebase 4
```
2. Modify `ip_config.txt` and copy dgl-ke to all the machines
3. Run
```bash
./launch.sh \
~/dgl/apps/kg/distributed \
./fb15k_transe_l2.sh \
ubuntu ~/mykey.pem
```
\ No newline at end of file
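As a rough check of step 1 (an illustration inferred from `partition.py` and `partition.sh`, not part of the commit), each partition directory should end up holding the training triples, the id mappings, and a copy of the relation mapping:

```python
# Hypothetical sanity check after running: ./partition.sh ../data FB15k 4
import os

data_path, data_set, k = '../data', 'FB15k', 4
for part_id in range(k):
    part_dir = os.path.join(data_path, data_set, 'partition_%d' % part_id)
    # train.txt, local_to_global.txt and partition_book.txt are written by partition.py;
    # relations.dict is copied in by partition.sh (Freebase uses relation2id.txt instead).
    for fname in ['train.txt', 'local_to_global.txt', 'partition_book.txt', 'relations.dict']:
        assert os.path.exists(os.path.join(part_dir, fname)), fname
```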
@@ -4,6 +4,7 @@
 ##################################################################################
 machine_id=$1
 server_count=$2
+machine_count=$3

 # Delete the temp file
 rm *-shape
@@ -26,4 +27,4 @@ done
 ##################################################################################
 MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 DGLBACKEND=pytorch python3 ../kvclient.py --model TransE_l2 --dataset FB15k \
 --batch_size 1000 --neg_sample_size 200 --hidden_dim 400 --gamma 19.9 --lr 0.25 --max_step 500 --log_interval 100 --num_thread 1 \
---batch_size_eval 16 --test -adv --regularization_coef 1e-9 --total_machine 4 --num_client 16
+--batch_size_eval 16 --test -adv --regularization_coef 1e-9 --total_machine $machine_count --num_client 16
\ No newline at end of file
@@ -4,6 +4,7 @@
 ##################################################################################
 machine_id=$1
 server_count=$2
+machine_count=$3

 # Delete the temp file
 rm *-shape
@@ -26,4 +27,4 @@ done
 ##################################################################################
 MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 DGLBACKEND=pytorch python3 ../kvclient.py --model ComplEx --dataset Freebase \
 --batch_size 1024 --neg_sample_size 256 --hidden_dim 400 --gamma 143.0 --lr 0.1 --max_step 12500 --log_interval 100 \
---batch_size_eval 1000 --neg_sample_size_eval 1000 --test -adv --total_machine 4 --num_thread 1 --num_client 40
+--batch_size_eval 1000 --neg_sample_size_eval 1000 --test -adv --total_machine $machine_count --num_thread 1 --num_client 40
\ No newline at end of file
@@ -4,6 +4,7 @@
 ##################################################################################
 machine_id=$1
 server_count=$2
+machine_count=$3

 # Delete the temp file
 rm *-shape
@@ -26,4 +27,4 @@ done
 ##################################################################################
 MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 DGLBACKEND=pytorch python3 ../kvclient.py --model DistMult --dataset Freebase \
 --batch_size 1024 --neg_sample_size 256 --hidden_dim 400 --gamma 143.0 --lr 0.08 --max_step 12500 --log_interval 100 \
---batch_size_eval 1000 --neg_sample_size_eval 1000 --test -adv --total_machine 4 --num_thread 1 --num_client 40
+--batch_size_eval 1000 --neg_sample_size_eval 1000 --test -adv --total_machine $machine_count --num_thread 1 --num_client 40
\ No newline at end of file
@@ -4,6 +4,7 @@
 ##################################################################################
 machine_id=$1
 server_count=$2
+machine_count=$3

 # Delete the temp file
 rm *-shape
@@ -26,4 +27,4 @@ done
 ##################################################################################
 MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 DGLBACKEND=pytorch python3 ../kvclient.py --model TransE_l2 --dataset Freebase \
 --batch_size 1000 --neg_sample_size 200 --hidden_dim 400 --gamma 10 --lr 0.1 --max_step 12500 --log_interval 100 --num_thread 1 \
---batch_size_eval 1000 --neg_sample_size_eval 1000 --test -adv --regularization_coef 1e-9 --total_machine 4 --num_client 40
+--batch_size_eval 1000 --neg_sample_size_eval 1000 --test -adv --regularization_coef 1e-9 --total_machine $machine_count --num_client 40
\ No newline at end of file
 ##################################################################################
 # User runs this script to launch distributed jobs on cluster
 ##################################################################################
-script_path=~/dgl/apps/kg/distributed
-script_file=./fb15k_transe_l2.sh
-user_name=ubuntu
-ssh_key=~/mctt.pem
+script_path=$1
+script_file=$2
+user_name=$3
+ssh_key=$4

 server_count=$(awk 'NR==1 {print $3}' ip_config.txt)
+machine_count=$(awk 'END{print NR}' ip_config.txt)

 # run command on remote machine
 LINE_LOW=2
@@ -18,8 +19,14 @@
     ip=$(awk 'NR=='$LINE_LOW' {print $1}' ip_config.txt)
     let LINE_LOW+=1
     let s_id+=1
-    ssh -i $ssh_key $user_name@$ip 'cd '$script_path'; '$script_file' '$s_id' '$server_count' ' &
+    if test -z "$ssh_key"
+    then
+        ssh $user_name@$ip 'cd '$script_path'; '$script_file' '$s_id' '$server_count' '$machine_count'' &
+    else
+        ssh -i $ssh_key $user_name@$ip 'cd '$script_path'; '$script_file' '$s_id' '$server_count' '$machine_count'' &
+    fi
 done

 # run command on local machine
-$script_file 0 $server_count
+$script_file 0 $server_count $machine_count
\ No newline at end of file
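`launch.sh` only describes `ip_config.txt` implicitly through its awk calls: one line per machine, the machine address in the first field, and the per-machine server count in the third field of the first line. A small Python equivalent of that parsing (an illustration resting on an assumption about the file format, not part of the commit):

```python
# Mirrors the awk commands in launch.sh (assumed ip_config.txt format):
#   server_count  = 3rd field of the first line
#   machine_count = number of lines
#   lines 2..N    = remote machines the script ssh'es into (address in the 1st field)
with open('ip_config.txt') as f:
    rows = [line.split() for line in f if line.strip()]

server_count = rows[0][2]
machine_count = len(rows)
remote_ips = [fields[0] for fields in rows[1:]]
print(server_count, machine_count, remote_ips)
```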
##################################################################################
# User runs this script to partition a graph using METIS
##################################################################################
DATA_PATH=$1
DATA_SET=$2
K=$3
# partition graph
python3 ../partition.py --dataset $DATA_SET -k $K --data_path $DATA_PATH
# copy related file to partition
PART_ID=0
while [ $PART_ID -lt $K ]
do
cp $DATA_PATH/$DATA_SET/relation* $DATA_PATH/$DATA_SET/partition_$PART_ID
let PART_ID+=1
done
\ No newline at end of file
@@ -180,7 +180,6 @@ def start_worker(args, logger):
     dataset, entity_partition_book, local2global = get_partition_dataset(
         args.data_path,
         args.dataset,
-        args.format,
         args.machine_id)
     n_entities = dataset.n_entities
......
@@ -87,7 +87,6 @@ def get_server_data(args, machine_id):
     g2l, dataset = get_server_partition_dataset(
         args.data_path,
         args.dataset,
-        args.format,
         machine_id)
     # Note that the dataset doesn't contain the triple
......
@@ -2,11 +2,51 @@ from dataloader import get_dataset
 import scipy as sp
 import numpy as np
 import argparse
-import signal
+import os
 import dgl
 from dgl import backend as F
 from dgl.data.utils import load_graphs, save_graphs

+def write_txt_graph(path, file_name, part_dict, total_nodes):
+    partition_book = [0] * total_nodes
+    for part_id in part_dict:
+        print('write graph %d...' % part_id)
+        # Get (h,r,t) triples
+        partition_path = path + str(part_id)
+        if not os.path.exists(partition_path):
+            os.mkdir(partition_path)
+        triple_file = os.path.join(partition_path, file_name)
+        f = open(triple_file, 'w')
+        graph = part_dict[part_id]
+        src, dst = graph.all_edges(form='uv', order='eid')
+        rel = graph.edata['tid']
+        assert len(src) == len(rel)
+        src = F.asnumpy(src)
+        dst = F.asnumpy(dst)
+        rel = F.asnumpy(rel)
+        for i in range(len(src)):
+            f.write(str(src[i])+'\t'+str(rel[i])+'\t'+str(dst[i])+'\n')
+        f.close()
+        # Get local2global
+        l2g_file = os.path.join(partition_path, 'local_to_global.txt')
+        f = open(l2g_file, 'w')
+        pid = F.asnumpy(graph.parent_nid)
+        for i in range(len(pid)):
+            f.write(str(pid[i])+'\n')
+        f.close()
+        # Update partition_book
+        partition = F.asnumpy(graph.ndata['part_id'])
+        for i in range(len(pid)):
+            partition_book[pid[i]] = partition[i]
+    # Write partition_book.txt
+    for part_id in part_dict:
+        partition_path = path + str(part_id)
+        pb_file = os.path.join(partition_path, 'partition_book.txt')
+        f = open(pb_file, 'w')
+        for i in range(len(partition_book)):
+            f.write(str(partition_book[i])+'\n')
+        f.close()
+
 def main():
     parser = argparse.ArgumentParser(description='Partition a knowledge graph')
     parser.add_argument('--data_path', type=str, default='data',
@@ -23,15 +63,21 @@ def main():
     args = parser.parse_args()
     num_parts = args.num_parts

+    print('load dataset..')
+
     # load dataset and samplers
     dataset = get_dataset(args.data_path, args.dataset, args.format, args.data_files)

+    print('construct graph...')
+
     src, etype_id, dst = dataset.train
     coo = sp.sparse.coo_matrix((np.ones(len(src)), (src, dst)),
                                shape=[dataset.n_entities, dataset.n_entities])
-    g = dgl.DGLGraph(coo, readonly=True, sort_csr=True)
+    g = dgl.DGLGraph(coo, readonly=True, multigraph=True, sort_csr=True)
     g.edata['tid'] = F.tensor(etype_id, F.int64)

+    print('partition graph...')
+
     part_dict = dgl.transform.metis_partition(g, num_parts, 1)

     tot_num_inner_edges = 0
@@ -46,9 +92,15 @@ def main():
         tot_num_inner_edges += num_inner_edges
         part.copy_from_parent()
-        save_graphs(args.data_path + '/part_' + str(part_id) + '.dgl', [part])

+    print('write graph to txt file...')
+
+    txt_file_graph = os.path.join(args.data_path, args.dataset)
+    txt_file_graph = os.path.join(txt_file_graph, 'partition_')
+    write_txt_graph(txt_file_graph, 'train.txt', part_dict, g.number_of_nodes())
+
     print('there are {} edges in the graph and {} edge cuts for {} partitions.'.format(
         g.number_of_edges(), g.number_of_edges() - tot_num_inner_edges, len(part_dict)))

 if __name__ == '__main__':
     main()
\ No newline at end of file
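To make the relationship between the three output files concrete, here is an illustrative read-back of one partition (not part of the commit; the path is hypothetical and follows the `partition_<id>` layout written by `write_txt_graph`):

```python
# Illustrative read-back of one partition produced by partition.py.
import os

part_dir = 'data/FB15k/partition_0'  # hypothetical output directory

def read_ints(path):
    with open(path) as f:
        return [int(line) for line in f]

local_to_global = read_ints(os.path.join(part_dir, 'local_to_global.txt'))
partition_book = read_ints(os.path.join(part_dir, 'partition_book.txt'))

with open(os.path.join(part_dir, 'train.txt')) as f:
    for line in f:
        h, r, t = map(int, line.strip().split('\t'))
        # train.txt stores local node ids; local_to_global maps them back to
        # global ids, and partition_book (indexed by global id) records which
        # partition/machine owns each entity.
        assert h < len(local_to_global) and t < len(local_to_global)
        assert partition_book[local_to_global[h]] >= 0
```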