Unverified commit f9ba3cdc authored by xiang song(charlie.song), committed by GitHub

[KGE] Kg loader (#1300)



* Now dataset accepts user-defined datasets

* Update README

* Fix eval

* Fix

* Fix Freebase

* Fix

* Fix

* upd

* upd

* Update README

* Update some docstrings.

* upd
Co-authored-by: Ubuntu <ubuntu@ip-172-31-8-26.us-east-2.compute.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-19-78.us-east-2.compute.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-60-78.ec2.internal>
parent 614acf2c
@@ -42,9 +42,9 @@ We will support multi-GPU training and distributed training in the near future.
The package can run with both PyTorch and MXNet. For PyTorch, it works with PyTorch v1.2 or newer.
For MXNet, it works with MXNet 1.5 or newer.
## Built-in Datasets

DGL-KE provides five built-in knowledge graphs:

| Dataset | #nodes | #edges | #relations |
|---------|--------|--------|------------|
@@ -129,24 +129,20 @@ when given (?, rel, tail).
### Input formats:

DGL-KE supports two knowledge graph input formats for user-defined datasets:

- raw_udd_[h|r|t]: raw user-defined dataset. In this format, users only need to provide the triples; the dataloader generates and manages the id mappings itself. The dataloader writes two files: entities.tsv for the entity id mapping and relations.tsv for the relation id mapping. The order of the head, relation and tail columns is described by [h|r|t]; for example, raw_udd_trh means the triples are stored in the order tail, relation, head. It should contain three files:
  - *train* stores the triples in the training set, one triple per line, e.g., [src_name, rel_name, dst_name], in the order specified by [h|r|t]
  - *valid* stores the triples in the validation set, in the same format
  - *test* stores the triples in the test set, in the same format

- udd_[h|r|t]: user-defined dataset. In this format, users provide the id mappings for entities and relations themselves. The order of the head, relation and tail columns is described by [h|r|t]; for example, udd_trh means the triples are stored in the order tail, relation, head. It should contain five files:
  - *entities* stores the mapping between entity name and entity Id
  - *relations* stores the mapping between relation name and relation Id
  - *train* stores the triples in the training set, one triple per line, e.g., [src_id, rel_id, dst_id], in the order specified by [h|r|t]
  - *valid* stores the triples in the validation set, in the same format
  - *test* stores the triples in the test set, in the same format
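To make the two layouts concrete, here is a minimal sketch (the toy_kg directory, file names and the three-triple graph are invented for illustration) that writes a raw_udd_hrt training file together with udd-style id mappings:

```python
import os

# A toy graph with three triples, stored in 'hrt' order (head \t rel \t tail).
raw_triples = [("Anna", "likes", "Bob"),
               ("Bob", "knows", "Carl"),
               ("Anna", "knows", "Carl")]

os.makedirs("toy_kg", exist_ok=True)

# raw_udd_hrt: only name-based triples are provided; the dataloader then
# generates entities.tsv and relations.tsv on its own.
with open("toy_kg/train", "w") as f:
    for h, r, t in raw_triples:
        f.write("{}\t{}\t{}\n".format(h, r, t))

# udd_hrt: the user supplies the 'name\tid' mappings and id-based triples.
entities = {"Anna": 0, "Bob": 1, "Carl": 2}
relations = {"likes": 0, "knows": 1}
with open("toy_kg/entities.tsv", "w") as f:
    f.writelines("{}\t{}\n".format(n, i) for n, i in entities.items())
with open("toy_kg/relations.tsv", "w") as f:
    f.writelines("{}\t{}\n".format(n, i) for n, i in relations.items())
with open("toy_kg/train_ids", "w") as f:
    for h, r, t in raw_triples:
        f.write("{}\t{}\t{}\n".format(entities[h], relations[r], entities[t]))
```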
### Output formats:
@@ -23,63 +23,86 @@ def _download_and_extract(url, path, filename):
            writer.write(chunk)
    print('Download finished. Unzipping the file...')
def _get_id(dict, key):
    # Return the id of `key`, assigning the next free id on first encounter.
    id = dict.get(key, None)
    if id is None:
        id = len(dict)
        dict[key] = id
    return id

def _parse_srd_format(format):
    # Map a permutation string such as 'hrt' to the column indices of
    # head, relation and tail in a triple line.
    if format == "hrt":
        return [0, 1, 2]
    if format == "htr":
        return [0, 2, 1]
    if format == "rht":
        return [1, 0, 2]
    if format == "rth":
        return [2, 0, 1]
    if format == "thr":
        return [1, 2, 0]
    if format == "trh":
        return [2, 1, 0]

class KGDataset:
    '''Load a knowledge graph

    The folder with a knowledge graph has five files:
    * entities stores the mapping between entity Id and entity name.
    * relations stores the mapping between relation Id and relation name.
    * train stores the triples in the training set.
    * valid stores the triples in the validation set.
    * test stores the triples in the test set.

    The mapping between entity (relation) Id and entity (relation) name is stored as 'id\tname'.
    The triples are stored as 'head_name\trelation_name\ttail_name'.
    '''
    def __init__(self, entity_path, relation_path,
                 train_path, valid_path=None, test_path=None,
                 format=[0,1,2], read_triple=True, only_train=False,
                 skip_first_line=False):
        self.entity2id, self.n_entities = self.read_entity(entity_path)
        self.relation2id, self.n_relations = self.read_relation(relation_path)

        if read_triple == True:
            self.train = self.read_triple(train_path, "train", skip_first_line, format)
            if only_train == False:
                self.valid = self.read_triple(valid_path, "valid", skip_first_line, format)
                self.test = self.read_triple(test_path, "test", skip_first_line, format)
    def read_entity(self, entity_path):
        with open(entity_path) as f:
            entity2id = {}
            for line in f:
                eid, entity = line.strip().split('\t')
                entity2id[entity] = int(eid)

        return entity2id, len(entity2id)

    def read_relation(self, relation_path):
        with open(relation_path) as f:
            relation2id = {}
            for line in f:
                rid, relation = line.strip().split('\t')
                relation2id[relation] = int(rid)

        return relation2id, len(relation2id)
    def read_triple(self, path, mode, skip_first_line=False, format=[0,1,2]):
        # mode: train/valid/test
        if path is None:
            return None

        heads = []
        tails = []
        rels = []
        with open(path) as f:
            if skip_first_line:
                _ = f.readline()
            for line in f:
                triple = line.strip().split('\t')
                h, r, t = triple[format[0]], triple[format[1]], triple[format[2]]
                heads.append(self.entity2id[h])
                rels.append(self.relation2id[r])
                tails.append(self.entity2id[t])
@@ -89,11 +112,130 @@ class KGDataset1:
        return (heads, rels, tails)
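Assuming the toy files from the README sketch above, the base class can be exercised directly; note that read_entity expects 'id\tname' lines, so the mapping files here are written in that order (all names are illustrative):

```python
with open("toy_kg/entities.dict", "w") as f:
    f.write("0\tAnna\n1\tBob\n2\tCarl\n")
with open("toy_kg/relations.dict", "w") as f:
    f.write("0\tlikes\n1\tknows\n")

dataset = KGDataset("toy_kg/entities.dict",
                    "toy_kg/relations.dict",
                    "toy_kg/train",
                    format=[0, 1, 2],   # columns are head, rel, tail
                    read_triple=True,
                    only_train=True)
print(dataset.n_entities, dataset.n_relations)  # 3 2
print(dataset.train)  # (heads, rels, tails) of the three training triples
```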
class KGDatasetFB15k(KGDataset):
    '''Load a knowledge graph FB15k

    The FB15k dataset has five files:
    * entities.dict stores the mapping between entity Id and entity name.
    * relations.dict stores the mapping between relation Id and relation name.
    * train.txt stores the triples in the training set.
    * valid.txt stores the triples in the validation set.
    * test.txt stores the triples in the test set.

    The mapping between entity (relation) Id and entity (relation) name is stored as 'id\tname'.
    The triples are stored as 'head_name\trelation_name\ttail_name'.
    '''
    def __init__(self, path, name='FB15k', read_triple=True, only_train=False):
        self.name = name
        url = 'https://data.dgl.ai/dataset/{}.zip'.format(name)

        if not os.path.exists(os.path.join(path, name)):
            print('File not found. Downloading from', url)
            _download_and_extract(url, path, name + '.zip')
        self.path = os.path.join(path, name)

        super(KGDatasetFB15k, self).__init__(os.path.join(self.path, 'entities.dict'),
                                             os.path.join(self.path, 'relations.dict'),
                                             os.path.join(self.path, 'train.txt'),
                                             os.path.join(self.path, 'valid.txt'),
                                             os.path.join(self.path, 'test.txt'),
                                             read_triple=read_triple,
                                             only_train=only_train)
class KGDatasetFB15k237(KGDataset):
    '''Load a knowledge graph FB15k-237

    The FB15k-237 dataset has five files:
    * entities.dict stores the mapping between entity Id and entity name.
    * relations.dict stores the mapping between relation Id and relation name.
    * train.txt stores the triples in the training set.
    * valid.txt stores the triples in the validation set.
    * test.txt stores the triples in the test set.

    The mapping between entity (relation) Id and entity (relation) name is stored as 'id\tname'.
    The triples are stored as 'head_name\trelation_name\ttail_name'.
    '''
    def __init__(self, path, name='FB15k-237', read_triple=True, only_train=False):
        self.name = name
        url = 'https://data.dgl.ai/dataset/{}.zip'.format(name)

        if not os.path.exists(os.path.join(path, name)):
            print('File not found. Downloading from', url)
            _download_and_extract(url, path, name + '.zip')
        self.path = os.path.join(path, name)

        super(KGDatasetFB15k237, self).__init__(os.path.join(self.path, 'entities.dict'),
                                                os.path.join(self.path, 'relations.dict'),
                                                os.path.join(self.path, 'train.txt'),
                                                os.path.join(self.path, 'valid.txt'),
                                                os.path.join(self.path, 'test.txt'),
                                                read_triple=read_triple,
                                                only_train=only_train)
class KGDatasetWN18(KGDataset):
    '''Load a knowledge graph wn18

    The wn18 dataset has five files:
    * entities.dict stores the mapping between entity Id and entity name.
    * relations.dict stores the mapping between relation Id and relation name.
    * train.txt stores the triples in the training set.
    * valid.txt stores the triples in the validation set.
    * test.txt stores the triples in the test set.

    The mapping between entity (relation) Id and entity (relation) name is stored as 'id\tname'.
    The triples are stored as 'head_name\trelation_name\ttail_name'.
    '''
    def __init__(self, path, name='wn18', read_triple=True, only_train=False):
        self.name = name
        url = 'https://data.dgl.ai/dataset/{}.zip'.format(name)

        if not os.path.exists(os.path.join(path, name)):
            print('File not found. Downloading from', url)
            _download_and_extract(url, path, name + '.zip')
        self.path = os.path.join(path, name)

        super(KGDatasetWN18, self).__init__(os.path.join(self.path, 'entities.dict'),
                                            os.path.join(self.path, 'relations.dict'),
                                            os.path.join(self.path, 'train.txt'),
                                            os.path.join(self.path, 'valid.txt'),
                                            os.path.join(self.path, 'test.txt'),
                                            read_triple=read_triple,
                                            only_train=only_train)
class KGDatasetWN18rr(KGDataset):
    '''Load a knowledge graph wn18rr

    The wn18rr dataset has five files:
    * entities.dict stores the mapping between entity Id and entity name.
    * relations.dict stores the mapping between relation Id and relation name.
    * train.txt stores the triples in the training set.
    * valid.txt stores the triples in the validation set.
    * test.txt stores the triples in the test set.

    The mapping between entity (relation) Id and entity (relation) name is stored as 'id\tname'.
    The triples are stored as 'head_name\trelation_name\ttail_name'.
    '''
    def __init__(self, path, name='wn18rr', read_triple=True, only_train=False):
        self.name = name
        url = 'https://data.dgl.ai/dataset/{}.zip'.format(name)

        if not os.path.exists(os.path.join(path, name)):
            print('File not found. Downloading from', url)
            _download_and_extract(url, path, name + '.zip')
        self.path = os.path.join(path, name)

        super(KGDatasetWN18rr, self).__init__(os.path.join(self.path, 'entities.dict'),
                                              os.path.join(self.path, 'relations.dict'),
                                              os.path.join(self.path, 'train.txt'),
                                              os.path.join(self.path, 'valid.txt'),
                                              os.path.join(self.path, 'test.txt'),
                                              read_triple=read_triple,
                                              only_train=only_train)
class KGDatasetFreebase(KGDataset):
    '''Load a knowledge graph Full Freebase

    The Freebase dataset has five files:
    * entity2id.txt stores the mapping between entity name and entity Id.
    * relation2id.txt stores the mapping between relation name and relation Id.
    * train.txt stores the triples in the training set.
@@ -101,10 +243,10 @@ class KGDataset2:
    * test.txt stores the triples in the test set.

    The mapping between entity (relation) name and entity (relation) Id is stored as 'name\tid'.
    The triples are stored as 'head_nid\trelation_id\ttail_nid'.
    '''
    def __init__(self, path, name='Freebase', read_triple=True, only_train=False):
        self.name = name
        url = 'https://data.dgl.ai/dataset/{}.zip'.format(name)
        if not os.path.exists(os.path.join(path, name)):
@@ -112,31 +254,30 @@ class KGDataset2:
            _download_and_extract(url, path, '{}.zip'.format(name))
        self.path = os.path.join(path, name)

        super(KGDatasetFreebase, self).__init__(os.path.join(self.path, 'entity2id.txt'),
                                                os.path.join(self.path, 'relation2id.txt'),
                                                os.path.join(self.path, 'train.txt'),
                                                os.path.join(self.path, 'valid.txt'),
                                                os.path.join(self.path, 'test.txt'),
                                                read_triple=read_triple,
                                                only_train=only_train)

    def read_entity(self, entity_path):
        # The first line of entity2id.txt holds the total entity count;
        # the names themselves are never materialized.
        with open(entity_path) as f_ent:
            n_entities = int(f_ent.readline()[:-1])
        return None, n_entities

    def read_relation(self, relation_path):
        # The first line of relation2id.txt holds the total relation count.
        with open(relation_path) as f_rel:
            n_relations = int(f_rel.readline()[:-1])
        return None, n_relations
    def read_triple(self, path, mode, skip_first_line=False):
        heads = []
        tails = []
        rels = []
        print('Reading {} triples....'.format(mode))
        with open(path) as f:
            if skip_first_line:
                _ = f.readline()
            for line in f:
@@ -150,27 +291,203 @@ class KGDataset2:
        print('Finished. Read {} {} triples.'.format(len(heads), mode))
        return (heads, rels, tails)
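The Freebase mapping files are expected to start with a count line, which is why read_entity and read_relation above only parse the first line. A sketch of that convention (toy file and names are illustrative):

```python
# entity2id.txt style: the first line is the entity count,
# every following line is 'name\tid'.
with open("toy_entity2id.txt", "w") as f:
    f.write("3\n")
    f.write("Anna\t0\nBob\t1\nCarl\t2\n")

with open("toy_entity2id.txt") as f_ent:
    n_entities = int(f_ent.readline()[:-1])  # -> 3; the names are never loaded
```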
class KGDatasetUDDRaw(KGDataset):
    '''Load a raw user defined knowledge graph dataset

    In a raw user defined dataset the triples are stored as names:
    * train stores the triples in the training set. In format [src_name, rel_name, dst_name]
    * valid stores the triples in the validation set. In format [src_name, rel_name, dst_name]
    * test stores the triples in the test set. In format [src_name, rel_name, dst_name]

    The loader builds the id mappings itself and writes them to entities.tsv
    and relations.tsv as 'name\tid'.
    '''
    def __init__(self, path, name, files, format):
        self.name = name
        for f in files:
            assert os.path.exists(os.path.join(path, f)), \
                'File {} does not exist in {}'.format(f, path)

        assert len(format) == 3
        format = _parse_srd_format(format)
        self.load_entity_relation(path, files, format)

        # Only the train set is provided
        if len(files) == 1:
            super(KGDatasetUDDRaw, self).__init__("entities.tsv",
                                                  "relation.tsv",
                                                  os.path.join(path, files[0]),
                                                  format=format,
                                                  read_triple=True,
                                                  only_train=True)
        # Train, validation and test sets are provided
        if len(files) == 3:
            super(KGDatasetUDDRaw, self).__init__("entities.tsv",
                                                  "relation.tsv",
                                                  os.path.join(path, files[0]),
                                                  os.path.join(path, files[1]),
                                                  os.path.join(path, files[2]),
                                                  format=format,
                                                  read_triple=True,
                                                  only_train=False)

    def load_entity_relation(self, path, files, format):
        entity_map = {}
        rel_map = {}
        for fi in files:
            with open(os.path.join(path, fi)) as f:
                for line in f:
                    triple = line.strip().split('\t')
                    src, rel, dst = triple[format[0]], triple[format[1]], triple[format[2]]
                    src_id = _get_id(entity_map, src)
                    dst_id = _get_id(entity_map, dst)
                    rel_id = _get_id(rel_map, rel)

        entities = ["{}\t{}\n".format(key, val) for key, val in entity_map.items()]
        with open(os.path.join(path, "entities.tsv"), "w+") as f:
            f.writelines(entities)
        self.entity2id = entity_map
        self.n_entities = len(entities)

        relations = ["{}\t{}\n".format(key, val) for key, val in rel_map.items()]
        with open(os.path.join(path, "relations.tsv"), "w+") as f:
            f.writelines(relations)
        self.relation2id = rel_map
        self.n_relations = len(relations)

    def read_entity(self, entity_path):
        # The mappings were already built in load_entity_relation; the path is unused.
        return self.entity2id, self.n_entities

    def read_relation(self, relation_path):
        # The mappings were already built in load_entity_relation; the path is unused.
        return self.relation2id, self.n_relations
class KGDatasetUDD(KGDataset):
    '''Load a user defined knowledge graph dataset

    The user defined dataset has five files:
    * entities stores the mapping between entity name and entity Id.
    * relations stores the mapping between relation name and relation Id.
    * train stores the triples in the training set. In format [src_id, rel_id, dst_id]
    * valid stores the triples in the validation set. In format [src_id, rel_id, dst_id]
    * test stores the triples in the test set. In format [src_id, rel_id, dst_id]

    The mapping between entity (relation) name and entity (relation) Id is stored as 'name\tid'.
    The triples are stored as 'head_nid\trelation_id\ttail_nid'.
    '''
    def __init__(self, path, name, files, format, read_triple=True, only_train=False):
        self.name = name
        for f in files:
            assert os.path.exists(os.path.join(path, f)), \
                'File {} does not exist in {}'.format(f, path)

        format = _parse_srd_format(format)
        # Entity map, relation map and train set are provided
        if len(files) == 3:
            super(KGDatasetUDD, self).__init__(os.path.join(path, files[0]),
                                               os.path.join(path, files[1]),
                                               os.path.join(path, files[2]),
                                               None,
                                               None,
                                               format=format,
                                               read_triple=read_triple,
                                               only_train=only_train)
        if len(files) == 5:
            super(KGDatasetUDD, self).__init__(os.path.join(path, files[0]),
                                               os.path.join(path, files[1]),
                                               os.path.join(path, files[2]),
                                               os.path.join(path, files[3]),
                                               os.path.join(path, files[4]),
                                               format=format,
                                               read_triple=read_triple,
                                               only_train=only_train)

    def read_entity(self, entity_path):
        # Triples already use integer ids, so only the entity count is needed.
        n_entities = 0
        with open(entity_path) as f_ent:
            for line in f_ent:
                n_entities += 1
        return None, n_entities

    def read_relation(self, relation_path):
        # Triples already use integer ids, so only the relation count is needed.
        n_relations = 0
        with open(relation_path) as f_rel:
            for line in f_rel:
                n_relations += 1
        return None, n_relations

    def read_triple(self, path, mode, skip_first_line=False, format=[0,1,2]):
        heads = []
        tails = []
        rels = []
        print('Reading {} triples....'.format(mode))
        with open(path) as f:
            if skip_first_line:
                _ = f.readline()
            for line in f:
                triple = line.strip().split('\t')
                h, r, t = triple[format[0]], triple[format[1]], triple[format[2]]
                heads.append(int(h))
                tails.append(int(t))
                rels.append(int(r))

        heads = np.array(heads, dtype=np.int64)
        tails = np.array(tails, dtype=np.int64)
        rels = np.array(rels, dtype=np.int64)
        print('Finished. Read {} {} triples.'.format(len(heads), mode))
        return (heads, rels, tails)
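For the id-based format, the overridden readers only count lines, since the triples already carry integer ids. A minimal sketch reusing the toy udd files from the README example (paths are illustrative):

```python
dataset = KGDatasetUDD("toy_kg", "toy",
                       ["entities.tsv", "relations.tsv", "train_ids"],
                       "hrt", read_triple=True, only_train=True)
print(dataset.n_entities, dataset.n_relations)  # 3 2
print(dataset.train[0])  # heads as an int64 numpy array
```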
def get_dataset(data_path, data_name, format_str, files=None):
    if format_str == 'built_in':
        if data_name == 'Freebase':
            dataset = KGDatasetFreebase(data_path)
        elif data_name == 'FB15k':
            dataset = KGDatasetFB15k(data_path)
        elif data_name == 'FB15k-237':
            dataset = KGDatasetFB15k237(data_path)
        elif data_name == 'wn18':
            dataset = KGDatasetWN18(data_path)
        elif data_name == 'wn18rr':
            dataset = KGDatasetWN18rr(data_path)
        else:
            assert False, "Unknown dataset {}".format(data_name)
    elif format_str.startswith('raw_udd'):
        # raw user defined dataset, e.g. 'raw_udd_hrt'
        format = format_str[8:]
        dataset = KGDatasetUDDRaw(data_path, data_name, files, format)
    elif format_str.startswith('udd'):
        # user defined dataset, e.g. 'udd_hrt'
        format = format_str[4:]
        dataset = KGDatasetUDD(data_path, data_name, files, format)
    else:
        assert False, "Unknown format {}".format(format_str)

    return dataset
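A minimal sketch of how the dispatcher is called for each family of formats (the built-in call downloads data, so it is left commented out; toy paths as in the earlier sketches):

```python
# Built-in dataset: fetched from data.dgl.ai on first use.
# dataset = get_dataset('./data', 'FB15k', 'built_in')

# Raw user defined dataset: the trailing 'hrt' is sliced out of the
# format string ('raw_udd_hrt'[8:] == 'hrt') and parsed by _parse_srd_format.
dataset = get_dataset('toy_kg', 'toy', 'raw_udd_hrt', files=['train'])
print(dataset.n_entities, dataset.n_relations)  # 3 2
```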
def get_partition_dataset(data_path, data_name, format_str, part_id, files=None):
    part_name = os.path.join(data_name, 'part_'+str(part_id))

    if format_str == 'built_in':
        if data_name == 'Freebase':
            dataset = KGDatasetFreebase(data_path, part_name, read_triple=True, only_train=True)
        elif data_name == 'FB15k':
            dataset = KGDatasetFB15k(data_path, part_name, read_triple=True, only_train=True)
        elif data_name == 'FB15k-237':
            dataset = KGDatasetFB15k237(data_path, part_name, read_triple=True, only_train=True)
        elif data_name == 'wn18':
            dataset = KGDatasetWN18(data_path, part_name, read_triple=True, only_train=True)
        elif data_name == 'wn18rr':
            dataset = KGDatasetWN18rr(data_path, part_name, read_triple=True, only_train=True)
        else:
            assert False, "Unknown dataset {}".format(data_name)
    elif format_str.startswith('raw_udd'):
        # user defined dataset
        assert False, "When using partitioned dataset, we assume dataset will not be raw"
    elif format_str.startswith('udd'):
        # user defined dataset; `files` must list the entity, relation and train files
        format = format_str[4:]
        dataset = KGDatasetUDD(data_path, data_name, files, format, read_triple=True, only_train=True)
    else:
        assert False, "Unknown format {}".format(format_str)

    path = os.path.join(data_path, part_name)
@@ -186,16 +503,31 @@ def get_partition_dataset(data_path, data_name, format_str, part_id):
    return dataset, partition_book, local_to_global
def get_server_partition_dataset(data_path, data_name, format_str, part_id, files=None):
    part_name = os.path.join(data_name, 'part_'+str(part_id))

    if format_str == 'built_in':
        if data_name == 'Freebase':
            dataset = KGDatasetFreebase(data_path, part_name, read_triple=False, only_train=True)
        elif data_name == 'FB15k':
            dataset = KGDatasetFB15k(data_path, part_name, read_triple=False, only_train=True)
        elif data_name == 'FB15k-237':
            dataset = KGDatasetFB15k237(data_path, part_name, read_triple=False, only_train=True)
        elif data_name == 'wn18':
            dataset = KGDatasetWN18(data_path, part_name, read_triple=False, only_train=True)
        elif data_name == 'wn18rr':
            dataset = KGDatasetWN18rr(data_path, part_name, read_triple=False, only_train=True)
        else:
            assert False, "Unknown dataset {}".format(data_name)
    elif format_str.startswith('raw_udd'):
        # user defined dataset
        assert False, "When using partitioned dataset, we assume dataset will not be raw"
    elif format_str.startswith('udd'):
        # user defined dataset; `files` must list the entity, relation and train files
        format = format_str[4:]
        dataset = KGDatasetUDD(data_path, data_name, files, format, read_triple=False, only_train=True)
    else:
        assert False, "Unknown format {}".format(format_str)

    path = os.path.join(data_path, part_name)
@@ -29,8 +29,11 @@ class ArgParser(argparse.ArgumentParser):
                          help='root path of all dataset')
        self.add_argument('--dataset', type=str, default='FB15k',
                          help='dataset name, under data_path')
        self.add_argument('--format', type=str, default='built_in',
                          help='the format of the dataset, it can be built_in, '
                               'raw_udd_{htr} and udd_{htr}')
        self.add_argument('--data_files', type=str, default=None, nargs='+',
                          help='a list of data files, e.g. entity relation train valid test')
        self.add_argument('--model_path', type=str, default='ckpts',
                          help='the place where models are saved')
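Putting the two new flags together (a stripped-down stand-in for ArgParser, since the real one defines many more options; values are illustrative and reuse the toy files above):

```python
import argparse

p = argparse.ArgumentParser()
p.add_argument('--data_path', type=str, default='data')
p.add_argument('--dataset', type=str, default='FB15k')
p.add_argument('--format', type=str, default='built_in')
p.add_argument('--data_files', type=str, default=None, nargs='+')

args = p.parse_args(['--data_path', 'toy_kg', '--dataset', 'toy',
                     '--format', 'raw_udd_hrt', '--data_files', 'train'])
dataset = get_dataset(args.data_path, args.dataset, args.format, args.data_files)
```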
@@ -97,7 +100,7 @@ def main(args):
    assert not args.eval_filter, "if negative sampling based on degree, we can't filter positive edges."

    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format, args.data_files)
    args.pickle_graph = False
    args.train = False
    args.valid = False
@@ -31,8 +31,11 @@ class ArgParser(argparse.ArgumentParser):
                          help='root path of all dataset')
        self.add_argument('--dataset', type=str, default='FB15k',
                          help='dataset name, under data_path')
        self.add_argument('--format', type=str, default='built_in',
                          help='the format of the dataset, it can be built_in, '
                               'raw_udd_{htr} and udd_{htr}')
        self.add_argument('--data_files', type=str, default=None, nargs='+',
                          help='a list of data files, e.g. entity relation train valid test')
        self.add_argument('--save_path', type=str, default='ckpts',
                          help='place to save models and logs')
        self.add_argument('--save_emb', type=str, default=None,
@@ -156,7 +159,7 @@ def get_logger(args):
def run(args, logger):
    train_time_start = time.time()
    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format, args.data_files)
    n_entities = dataset.n_entities
    n_relations = dataset.n_relations
    if args.neg_sample_size_test < 0:
@@ -189,7 +189,6 @@ def test(args, model, test_samplers, rank=0, mode='Test', queue=None):
    with th.no_grad():
        logs = []
        for sampler in test_samplers:
            for pos_g, neg_g in sampler:
                model.forward_test(pos_g, neg_g, logs, gpu_id)