Unverified Commit f9ba3cdc authored by xiang song(charlie.song), committed by GitHub

[KGE] Kg loader (#1300)



* Now dataset accepts user defined datasets

* Update README

* Fix eval

* Fix

* Fix Freebase

* Fix

* Fix

* upd

* upd

* Update README

* Update some docstrings.

* upd
Co-authored-by: Ubuntu <ubuntu@ip-172-31-8-26.us-east-2.compute.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-19-78.us-east-2.compute.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-60-78.ec2.internal>
parent 614acf2c
@@ -42,9 +42,9 @@ We will support multi-GPU training and distributed training in the near future.

The package can run with both Pytorch and MXNet. For Pytorch, it works with Pytorch v1.2 or newer.
For MXNet, it works with MXNet 1.5 or newer.

## Built-in Datasets
DGL-KE provides five built-in knowledge graphs:

| Dataset | #nodes | #edges | #relations |
|---------|--------|--------|------------|
@@ -129,24 +129,20 @@ when given (?, rel, tail).

### Input formats:
DGL-KE supports two knowledge graph input formats for user defined datasets.

Format 1:
- raw_udd_[h|r|t], raw user defined dataset. In this format, the user only needs to provide the triples and let the dataloader generate and manage the id mappings. The dataloader generates two files: entities.tsv for the entity id mapping and relations.tsv for the relation id mapping. The order of head, relation and tail entities is described by [h|r|t]; for example, raw_udd_trh means the triples are stored in the order of tail, relation and head. It should contain three files:
  - *train* stores the triples in the training set. Each line is a triple, e.g., [src_name, rel_name, dst_name], following the order specified in [h|r|t].
  - *valid* stores the triples in the validation set. Each line is a triple, e.g., [src_name, rel_name, dst_name], following the order specified in [h|r|t].
  - *test* stores the triples in the test set. Each line is a triple, e.g., [src_name, rel_name, dst_name], following the order specified in [h|r|t].

Format 2:
- udd_[h|r|t], user defined dataset. In this format, the user should provide the id mappings for entities and relations. The order of head, relation and tail entities is described by [h|r|t]; for example, udd_trh means the triples are stored in the order of tail, relation and head. It should contain five files:
  - *entities* stores the mapping between entity name and entity Id.
  - *relations* stores the mapping between relation name and relation Id.
  - *train* stores the triples in the training set. Each line is a triple, e.g., [src_id, rel_id, dst_id], following the order specified in [h|r|t].
  - *valid* stores the triples in the validation set. Each line is a triple, e.g., [src_id, rel_id, dst_id], following the order specified in [h|r|t].
  - *test* stores the triples in the test set. Each line is a triple, e.g., [src_id, rel_id, dst_id], following the order specified in [h|r|t].
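For concreteness, a minimal raw_udd_hrt training file can be produced as follows (a sketch with hypothetical entity and relation names; each line is one tab-separated triple):

```python
# Toy example: write a train file in raw_udd_hrt order (head, relation, tail).
# Names are free-form strings; the dataloader assigns integer ids and emits
# entities.tsv and relations.tsv next to the input files.
triples = [
    ("Paris", "capital_of", "France"),
    ("Berlin", "capital_of", "Germany"),
    ("France", "borders", "Germany"),
]
with open("train.tsv", "w") as f:
    for h, r, t in triples:
        f.write("{}\t{}\t{}\n".format(h, r, t))
```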
### Output formats:
...
@@ -23,63 +23,86 @@ def _download_and_extract(url, path, filename):
            writer.write(chunk)
    print('Download finished. Unzipping the file...')

def _get_id(dict, key):
    id = dict.get(key, None)
    if id is None:
        id = len(dict)
        dict[key] = id
    return id

def _parse_srd_format(format):
    # Map a storage-order string such as 'trh' to the column positions of
    # (head, relation, tail) in a tab-separated line.
    if format == "hrt":
        return [0, 1, 2]
    if format == "htr":
        return [0, 2, 1]
    if format == "rht":
        return [1, 0, 2]
    if format == "rth":
        return [2, 0, 1]
    if format == "thr":
        return [1, 2, 0]
    if format == "trh":
        return [2, 1, 0]

class KGDataset:
    '''Load a knowledge graph

    The folder with a knowledge graph has five files:
    * entities stores the mapping between entity Id and entity name.
    * relations stores the mapping between relation Id and relation name.
    * train stores the triples in the training set.
    * valid stores the triples in the validation set.
    * test stores the triples in the test set.

    The mapping between entity (relation) Id and entity (relation) name is stored as 'id\tname'.
    The triples are stored as 'head_name\trelation_name\ttail_name'.
    '''
    def __init__(self, entity_path, relation_path,
                 train_path, valid_path=None, test_path=None,
                 format=[0,1,2], read_triple=True, only_train=False,
                 skip_first_line=False):
        self.entity2id, self.n_entities = self.read_entity(entity_path)
        self.relation2id, self.n_relations = self.read_relation(relation_path)

        if read_triple == True:
            self.train = self.read_triple(train_path, "train", skip_first_line, format)
            if only_train == False:
                self.valid = self.read_triple(valid_path, "valid", skip_first_line, format)
                self.test = self.read_triple(test_path, "test", skip_first_line, format)

    def read_entity(self, entity_path):
        with open(entity_path) as f:
            entity2id = {}
            for line in f:
                eid, entity = line.strip().split('\t')
                entity2id[entity] = int(eid)
        return entity2id, len(entity2id)

    def read_relation(self, relation_path):
        with open(relation_path) as f:
            relation2id = {}
            for line in f:
                rid, relation = line.strip().split('\t')
                relation2id[relation] = int(rid)
        return relation2id, len(relation2id)

    def read_triple(self, path, mode, skip_first_line=False, format=[0,1,2]):
        # mode: train/valid/test
        if path is None:
            return None

        heads = []
        tails = []
        rels = []
        with open(path) as f:
            if skip_first_line:
                _ = f.readline()
            for line in f:
                triple = line.strip().split('\t')
                h, r, t = triple[format[0]], triple[format[1]], triple[format[2]]
                heads.append(self.entity2id[h])
                rels.append(self.relation2id[r])
                tails.append(self.entity2id[t])

@@ -89,11 +112,130 @@ class KGDataset1:
        return (heads, rels, tails)
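
# Illustration (editor's sketch): read_triple returns three parallel lists;
# entry i of each encodes the i-th edge as integer ids. Downstream code
# typically unpacks them like this (assuming `dataset` was constructed):
#
#     heads, rels, tails = dataset.train
#     print(heads[0], rels[0], tails[0])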

class KGDatasetFB15k(KGDataset):
    '''Load the FB15k knowledge graph

    The FB15k dataset has five files:
    * entities.dict stores the mapping between entity Id and entity name.
    * relations.dict stores the mapping between relation Id and relation name.
    * train.txt stores the triples in the training set.
    * valid.txt stores the triples in the validation set.
    * test.txt stores the triples in the test set.

    The mapping between entity (relation) Id and entity (relation) name is stored as 'id\tname'.
    The triples are stored as 'head_name\trelation_name\ttail_name'.
    '''
    def __init__(self, path, name='FB15k', read_triple=True, only_train=False):
        self.name = name
        url = 'https://data.dgl.ai/dataset/{}.zip'.format(name)

        if not os.path.exists(os.path.join(path, name)):
            print('File not found. Downloading from', url)
            _download_and_extract(url, path, name + '.zip')
        self.path = os.path.join(path, name)

        super(KGDatasetFB15k, self).__init__(os.path.join(self.path, 'entities.dict'),
                                             os.path.join(self.path, 'relations.dict'),
                                             os.path.join(self.path, 'train.txt'),
                                             os.path.join(self.path, 'valid.txt'),
                                             os.path.join(self.path, 'test.txt'),
                                             read_triple=read_triple,
                                             only_train=only_train)
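
# Illustration (editor's sketch): the built-in wrappers here and below all
# follow the same pattern: download the zip on first use, then delegate to
# KGDataset with the per-dataset file names.
#
#     dataset = KGDatasetFB15k('./data')   # downloads to ./data/FB15k if absent
#     print(dataset.n_entities, dataset.n_relations)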

class KGDatasetFB15k237(KGDataset):
    '''Load the FB15k-237 knowledge graph

    The FB15k-237 dataset has five files:
    * entities.dict stores the mapping between entity Id and entity name.
    * relations.dict stores the mapping between relation Id and relation name.
    * train.txt stores the triples in the training set.
    * valid.txt stores the triples in the validation set.
    * test.txt stores the triples in the test set.

    The mapping between entity (relation) Id and entity (relation) name is stored as 'id\tname'.
    The triples are stored as 'head_name\trelation_name\ttail_name'.
    '''
    def __init__(self, path, name='FB15k-237', read_triple=True, only_train=False):
        self.name = name
        url = 'https://data.dgl.ai/dataset/{}.zip'.format(name)

        if not os.path.exists(os.path.join(path, name)):
            print('File not found. Downloading from', url)
            _download_and_extract(url, path, name + '.zip')
        self.path = os.path.join(path, name)

        super(KGDatasetFB15k237, self).__init__(os.path.join(self.path, 'entities.dict'),
                                                os.path.join(self.path, 'relations.dict'),
                                                os.path.join(self.path, 'train.txt'),
                                                os.path.join(self.path, 'valid.txt'),
                                                os.path.join(self.path, 'test.txt'),
                                                read_triple=read_triple,
                                                only_train=only_train)

class KGDatasetWN18(KGDataset):
    '''Load the wn18 knowledge graph

    The wn18 dataset has five files:
    * entities.dict stores the mapping between entity Id and entity name.
    * relations.dict stores the mapping between relation Id and relation name.
    * train.txt stores the triples in the training set.
    * valid.txt stores the triples in the validation set.
    * test.txt stores the triples in the test set.

    The mapping between entity (relation) Id and entity (relation) name is stored as 'id\tname'.
    The triples are stored as 'head_name\trelation_name\ttail_name'.
    '''
    def __init__(self, path, name='wn18', read_triple=True, only_train=False):
        self.name = name
        url = 'https://data.dgl.ai/dataset/{}.zip'.format(name)

        if not os.path.exists(os.path.join(path, name)):
            print('File not found. Downloading from', url)
            _download_and_extract(url, path, name + '.zip')
        self.path = os.path.join(path, name)

        super(KGDatasetWN18, self).__init__(os.path.join(self.path, 'entities.dict'),
                                            os.path.join(self.path, 'relations.dict'),
                                            os.path.join(self.path, 'train.txt'),
                                            os.path.join(self.path, 'valid.txt'),
                                            os.path.join(self.path, 'test.txt'),
                                            read_triple=read_triple,
                                            only_train=only_train)

class KGDatasetWN18rr(KGDataset):
    '''Load the wn18rr knowledge graph

    The wn18rr dataset has five files:
    * entities.dict stores the mapping between entity Id and entity name.
    * relations.dict stores the mapping between relation Id and relation name.
    * train.txt stores the triples in the training set.
    * valid.txt stores the triples in the validation set.
    * test.txt stores the triples in the test set.

    The mapping between entity (relation) Id and entity (relation) name is stored as 'id\tname'.
    The triples are stored as 'head_name\trelation_name\ttail_name'.
    '''
    def __init__(self, path, name='wn18rr', read_triple=True, only_train=False):
        self.name = name
        url = 'https://data.dgl.ai/dataset/{}.zip'.format(name)

        if not os.path.exists(os.path.join(path, name)):
            print('File not found. Downloading from', url)
            _download_and_extract(url, path, name + '.zip')
        self.path = os.path.join(path, name)

        super(KGDatasetWN18rr, self).__init__(os.path.join(self.path, 'entities.dict'),
                                              os.path.join(self.path, 'relations.dict'),
                                              os.path.join(self.path, 'train.txt'),
                                              os.path.join(self.path, 'valid.txt'),
                                              os.path.join(self.path, 'test.txt'),
                                              read_triple=read_triple,
                                              only_train=only_train)

class KGDatasetFreebase(KGDataset):
    '''Load the full Freebase knowledge graph

    The Freebase dataset has five files:
    * entity2id.txt stores the mapping between entity name and entity Id.
    * relation2id.txt stores the mapping between relation name and relation Id.
    * train.txt stores the triples in the training set.

@@ -101,10 +243,10 @@ class KGDataset2:
    * test.txt stores the triples in the test set.

    The mapping between entity (relation) name and entity (relation) Id is stored as 'name\tid'.
    The triples are stored as 'head_nid\trelation_id\ttail_nid'.
    '''
    def __init__(self, path, name='Freebase', read_triple=True, only_train=False):
        self.name = name
        url = 'https://data.dgl.ai/dataset/{}.zip'.format(name)

        if not os.path.exists(os.path.join(path, name)):
@@ -112,31 +254,30 @@ class KGDataset2:
            _download_and_extract(url, path, '{}.zip'.format(name))
        self.path = os.path.join(path, name)

        super(KGDatasetFreebase, self).__init__(os.path.join(self.path, 'entity2id.txt'),
                                                os.path.join(self.path, 'relation2id.txt'),
                                                os.path.join(self.path, 'train.txt'),
                                                os.path.join(self.path, 'valid.txt'),
                                                os.path.join(self.path, 'test.txt'),
                                                read_triple=read_triple,
                                                only_train=only_train)

    def read_entity(self, entity_path):
        with open(entity_path) as f_ent:
            n_entities = int(f_ent.readline()[:-1])
        return None, n_entities

    def read_relation(self, relation_path):
        with open(relation_path) as f_rel:
            n_relations = int(f_rel.readline()[:-1])
        return None, n_relations

    def read_triple(self, path, mode, skip_first_line=False, format=None):
        # 'format' is accepted (and ignored) so the base class can call this
        # override with its usual positional arguments; Freebase triples are
        # already stored as integer ids.
        heads = []
        tails = []
        rels = []
        print('Reading {} triples....'.format(mode))
        with open(path) as f:
            if skip_first_line:
                _ = f.readline()
            for line in f:
@@ -150,27 +291,203 @@ class KGDataset2:
        print('Finished. Read {} {} triples.'.format(len(heads), mode))
        return (heads, rels, tails)
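
# Illustration (editor's sketch): the Freebase overrides above read only the
# first line of entity2id.txt / relation2id.txt, which holds the element
# count. A hypothetical file therefore looks like:
#
#     14951              <- the only value read_entity parses
#     /m/entity_a\t0     <- per-element rows (hypothetical names), skipped here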

class KGDatasetUDDRaw(KGDataset):
    '''Load a raw user defined knowledge graph dataset

    A raw user defined dataset provides the triple files only:
    * train stores the triples in the training set. In format [src_name, rel_name, dst_name].
    * valid stores the triples in the validation set. In format [src_name, rel_name, dst_name].
    * test stores the triples in the test set. In format [src_name, rel_name, dst_name].

    The loader generates entities.tsv and relations.tsv, which store the mapping
    between entity (relation) name and entity (relation) Id as 'name\tid'.
    '''
    def __init__(self, path, name, files, format):
        self.name = name
        for f in files:
            assert os.path.exists(os.path.join(path, f)), \
                'File {} does not exist in {}'.format(f, path)

        assert len(format) == 3
        format = _parse_srd_format(format)
        self.load_entity_relation(path, files, format)

        # Only the train set is provided
        if len(files) == 1:
            super(KGDatasetUDDRaw, self).__init__("entities.tsv",
                                                  "relations.tsv",
                                                  os.path.join(path, files[0]),
                                                  format=format,
                                                  read_triple=True,
                                                  only_train=True)
        # Train, validation and test sets are provided
        if len(files) == 3:
            super(KGDatasetUDDRaw, self).__init__("entities.tsv",
                                                  "relations.tsv",
                                                  os.path.join(path, files[0]),
                                                  os.path.join(path, files[1]),
                                                  os.path.join(path, files[2]),
                                                  format=format,
                                                  read_triple=True,
                                                  only_train=False)

    def load_entity_relation(self, path, files, format):
        entity_map = {}
        rel_map = {}
        for fi in files:
            with open(os.path.join(path, fi)) as f:
                for line in f:
                    triple = line.strip().split('\t')
                    src, rel, dst = triple[format[0]], triple[format[1]], triple[format[2]]
                    src_id = _get_id(entity_map, src)
                    dst_id = _get_id(entity_map, dst)
                    rel_id = _get_id(rel_map, rel)

        entities = ["{}\t{}\n".format(key, val) for key, val in entity_map.items()]
        with open(os.path.join(path, "entities.tsv"), "w+") as f:
            f.writelines(entities)
        self.entity2id = entity_map
        self.n_entities = len(entities)

        relations = ["{}\t{}\n".format(key, val) for key, val in rel_map.items()]
        with open(os.path.join(path, "relations.tsv"), "w+") as f:
            f.writelines(relations)
        self.relation2id = rel_map
        self.n_relations = len(relations)

    def read_entity(self, entity_path):
        return self.entity2id, self.n_entities

    def read_relation(self, relation_path):
        return self.relation2id, self.n_relations
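
# Illustration (editor's sketch): _get_id assigns ids in order of first
# appearance, so for the toy triples (Paris, capital_of, France),
# (Berlin, capital_of, Germany), (France, borders, Germany) the generated
# entities.tsv would read:
#
#     Paris\t0
#     France\t1
#     Berlin\t2
#     Germany\t3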

class KGDatasetUDD(KGDataset):
    '''Load a user defined knowledge graph dataset

    The user defined dataset has five files:
    * entities stores the mapping between entity name and entity Id.
    * relations stores the mapping between relation name and relation Id.
    * train stores the triples in the training set. In format [src_id, rel_id, dst_id].
    * valid stores the triples in the validation set. In format [src_id, rel_id, dst_id].
    * test stores the triples in the test set. In format [src_id, rel_id, dst_id].

    The mapping between entity (relation) name and entity (relation) Id is stored as 'name\tid'.
    The triples are stored as 'head_id\trelation_id\ttail_id'.
    '''
    def __init__(self, path, name, files, format, read_triple=True, only_train=False):
        self.name = name
        for f in files:
            assert os.path.exists(os.path.join(path, f)), \
                'File {} does not exist in {}'.format(f, path)

        format = _parse_srd_format(format)
        if len(files) == 3:
            super(KGDatasetUDD, self).__init__(os.path.join(path, files[0]),
                                               os.path.join(path, files[1]),
                                               os.path.join(path, files[2]),
                                               None,  # no validation set
                                               None,  # no test set
                                               format=format,
                                               read_triple=read_triple,
                                               only_train=only_train)
        if len(files) == 5:
            super(KGDatasetUDD, self).__init__(os.path.join(path, files[0]),
                                               os.path.join(path, files[1]),
                                               os.path.join(path, files[2]),
                                               os.path.join(path, files[3]),
                                               os.path.join(path, files[4]),
                                               format=format,
                                               read_triple=read_triple,
                                               only_train=only_train)

    def read_entity(self, entity_path):
        n_entities = 0
        with open(entity_path) as f_ent:
            for line in f_ent:
                n_entities += 1
        return None, n_entities

    def read_relation(self, relation_path):
        n_relations = 0
        with open(relation_path) as f_rel:
            for line in f_rel:
                n_relations += 1
        return None, n_relations

    def read_triple(self, path, mode, skip_first_line=False, format=[0,1,2]):
        # Guard mirrors the base class: valid/test may be absent.
        if path is None:
            return None

        heads = []
        tails = []
        rels = []
        print('Reading {} triples....'.format(mode))
        with open(path) as f:
            if skip_first_line:
                _ = f.readline()
            for line in f:
                triple = line.strip().split('\t')
                h, r, t = triple[format[0]], triple[format[1]], triple[format[2]]
                heads.append(int(h))
                tails.append(int(t))
                rels.append(int(r))
        heads = np.array(heads, dtype=np.int64)
        tails = np.array(tails, dtype=np.int64)
        rels = np.array(rels, dtype=np.int64)
        print('Finished. Read {} {} triples.'.format(len(heads), mode))
        return (heads, rels, tails)

def get_dataset(data_path, data_name, format_str, files=None):
    if format_str == 'built_in':
        if data_name == 'Freebase':
            dataset = KGDatasetFreebase(data_path)
        elif data_name == 'FB15k':
            dataset = KGDatasetFB15k(data_path)
        elif data_name == 'FB15k-237':
            dataset = KGDatasetFB15k237(data_path)
        elif data_name == 'wn18':
            dataset = KGDatasetWN18(data_path)
        elif data_name == 'wn18rr':
            dataset = KGDatasetWN18rr(data_path)
        else:
            assert False, "Unknown dataset {}".format(data_name)
    elif format_str.startswith('raw_udd'):
        # user defined dataset
        format = format_str[8:]
        dataset = KGDatasetUDDRaw(data_path, data_name, files, format)
    elif format_str.startswith('udd'):
        # user defined dataset
        format = format_str[4:]
        dataset = KGDatasetUDD(data_path, data_name, files, format)
    else:
        assert False, "Unknown format {}".format(format_str)

    return dataset
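
# Illustration (editor's sketch, hypothetical paths and names): loading a raw
# user defined dataset end to end. The loader writes entities.tsv and
# relations.tsv under ./my_data and maps every name to an integer id.
#
#     dataset = get_dataset('./my_data', 'my_kg', 'raw_udd_hrt',
#                           files=['train.tsv', 'valid.tsv', 'test.tsv'])
#     heads, rels, tails = dataset.train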

def get_partition_dataset(data_path, data_name, format_str, part_id, files=None):
    part_name = os.path.join(data_name, 'part_'+str(part_id))

    if format_str == 'built_in':
        if data_name == 'Freebase':
            dataset = KGDatasetFreebase(data_path, part_name, read_triple=True, only_train=True)
        elif data_name == 'FB15k':
            dataset = KGDatasetFB15k(data_path, part_name, read_triple=True, only_train=True)
        elif data_name == 'FB15k-237':
            dataset = KGDatasetFB15k237(data_path, part_name, read_triple=True, only_train=True)
        elif data_name == 'wn18':
            dataset = KGDatasetWN18(data_path, part_name, read_triple=True, only_train=True)
        elif data_name == 'wn18rr':
            dataset = KGDatasetWN18rr(data_path, part_name, read_triple=True, only_train=True)
        else:
            assert False, "Unknown dataset {}".format(data_name)
    elif format_str.startswith('raw_udd'):
        # user defined dataset
        assert False, "When using partitioned dataset, we assume dataset will not be raw"
    elif format_str.startswith('udd'):
        # user defined dataset; 'files' was referenced but missing from the
        # original signature, so it is added above as an optional argument
        format = format_str[4:]
        dataset = KGDatasetUDD(data_path, data_name, files, format, read_triple=True, only_train=True)
    else:
        assert False, "Unknown format {}".format(format_str)

    path = os.path.join(data_path, part_name)

@@ -186,16 +503,31 @@ def get_partition_dataset(data_path, data_name, format_str, part_id):
    return dataset, partition_book, local_to_global

def get_server_partition_dataset(data_path, data_name, format_str, part_id, files=None):
    part_name = os.path.join(data_name, 'part_'+str(part_id))

    if format_str == 'built_in':
        if data_name == 'Freebase':
            dataset = KGDatasetFreebase(data_path, part_name, read_triple=False, only_train=True)
        elif data_name == 'FB15k':
            dataset = KGDatasetFB15k(data_path, part_name, read_triple=False, only_train=True)
        elif data_name == 'FB15k-237':
            dataset = KGDatasetFB15k237(data_path, part_name, read_triple=False, only_train=True)
        elif data_name == 'wn18':
            dataset = KGDatasetWN18(data_path, part_name, read_triple=False, only_train=True)
        elif data_name == 'wn18rr':
            dataset = KGDatasetWN18rr(data_path, part_name, read_triple=False, only_train=True)
        else:
            assert False, "Unknown dataset {}".format(data_name)
    elif format_str.startswith('raw_udd'):
        # user defined dataset
        assert False, "When using partitioned dataset, we assume dataset will not be raw"
    elif format_str.startswith('udd'):
        # user defined dataset; 'files' added to the signature as above
        format = format_str[4:]
        dataset = KGDatasetUDD(data_path, data_name, files, format, read_triple=False, only_train=True)
    else:
        assert False, "Unknown format {}".format(format_str)

    path = os.path.join(data_path, part_name)
...
@@ -29,8 +29,11 @@ class ArgParser(argparse.ArgumentParser):
                          help='root path of all dataset')
        self.add_argument('--dataset', type=str, default='FB15k',
                          help='dataset name, under data_path')
        self.add_argument('--format', type=str, default='built_in',
                          help='the format of the dataset, it can be built_in, '
                               'raw_udd_{htr} and udd_{htr}')
        self.add_argument('--data_files', type=str, default=None, nargs='+',
                          help='a list of data files, e.g. entity relation train valid test')
        self.add_argument('--model_path', type=str, default='ckpts',
                          help='the place where models are saved')
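
# Illustration (editor's sketch): how the two new flags reach the loader.
# --data_files uses nargs='+', so it arrives as a list of file names:
#
#     args = ArgParser().parse_args(['--format', 'raw_udd_hrt',
#                                    '--data_files', 'train.tsv', 'valid.tsv', 'test.tsv'])
#     dataset = get_dataset(args.data_path, args.dataset, args.format, args.data_files)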
@@ -97,7 +100,7 @@ def main(args):
        assert not args.eval_filter, "if negative sampling based on degree, we can't filter positive edges."

    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format, args.data_files)
    args.pickle_graph = False
    args.train = False
    args.valid = False
...
@@ -31,8 +31,11 @@ class ArgParser(argparse.ArgumentParser):
                          help='root path of all dataset')
        self.add_argument('--dataset', type=str, default='FB15k',
                          help='dataset name, under data_path')
        self.add_argument('--format', type=str, default='built_in',
                          help='the format of the dataset, it can be built_in, '
                               'raw_udd_{htr} and udd_{htr}')
        self.add_argument('--data_files', type=str, default=None, nargs='+',
                          help='a list of data files, e.g. entity relation train valid test')
        self.add_argument('--save_path', type=str, default='ckpts',
                          help='place to save models and logs')
        self.add_argument('--save_emb', type=str, default=None,
@@ -156,7 +159,7 @@ def get_logger(args):
def run(args, logger):
    train_time_start = time.time()
    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format, args.data_files)
    n_entities = dataset.n_entities
    n_relations = dataset.n_relations
    if args.neg_sample_size_test < 0:
...
@@ -189,7 +189,6 @@ def test(args, model, test_samplers, rank=0, mode='Test', queue=None):
    with th.no_grad():
        logs = []
        for sampler in test_samplers:
            for pos_g, neg_g in sampler:
                model.forward_test(pos_g, neg_g, logs, gpu_id)
...