Unverified commit f9ba3cdc authored by xiang song(charlie.song), committed by GitHub

[KGE] Kg loader (#1300)



* Now dataset accepts user-defined datasets

* Update README

* Fix eval

* Fix

* Fix Freebase

* Fix

* Fix

* upd

* upd

* Update README

* Update some docstrings.

* upd
Co-authored-by: Ubuntu <ubuntu@ip-172-31-8-26.us-east-2.compute.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-19-78.us-east-2.compute.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-60-78.ec2.internal>
parent 614acf2c
@@ -42,9 +42,9 @@ We will support multi-GPU training and distributed training in a near future.
 The package can run with both Pytorch and MXNet. For Pytorch, it works with Pytorch v1.2 or newer.
 For MXNet, it works with MXNet 1.5 or newer.
-## Datasets
+## Built-in Datasets
-DGL-KE provides five knowledge graphs:
+DGL-KE provides five built-in knowledge graphs:
 | Dataset | #nodes | #edges | #relations |
 |---------|--------|--------|------------|
@@ -129,24 +129,20 @@ when given (?, rel, tail).
 ### Input formats:
-DGL-KE supports two knowledge graph input formats. A knowledge graph is stored
-using five files.
+DGL-KE supports two knowledge graph input formats for user-defined datasets.
 Format 1:
-- entities.dict contains pairs of (entity Id, entity name). The number of rows is the number of entities (nodes).
-- relations.dict contains pairs of (relation Id, relation name). The number of rows is the number of relations.
-- train.txt stores edges in the training set. They are stored as triples of (head, rel, tail).
-- valid.txt stores edges in the validation set. They are stored as triples of (head, rel, tail).
-- test.txt stores edges in the test set. They are stored as triples of (head, rel, tail).
+- raw_udd_[h|r|t], a raw user-defined dataset. In this format, users only need to provide the triples and let the dataloader generate and manage the id mappings. The dataloader generates two files: entities.tsv for the entity id mapping and relations.tsv for the relation id mapping. The order of the head, relation and tail entities is given by [h|r|t]; for example, raw_udd_trh means the triples are stored in the order of tail, relation, head. It should contain three files:
+  - *train* stores the triples in the training set, each a triple such as [src_name, rel_name, dst_name] in the order specified by [h|r|t]
+  - *valid* stores the triples in the validation set, each a triple such as [src_name, rel_name, dst_name] in the order specified by [h|r|t]
+  - *test* stores the triples in the test set, each a triple such as [src_name, rel_name, dst_name] in the order specified by [h|r|t]
 Format 2:
-- entity2id.txt contains pairs of (entity name, entity Id). The number of rows is the number of entities (nodes).
-- relation2id.txt contains pairs of (relation name, relation Id). The number of rows is the number of relations.
-- train.txt stores edges in the training set. They are stored as triples of (head, tail, rel).
-- valid.txt stores edges in the validation set. They are stored as a triple of (head, tail, rel).
-- test.txt stores edges in the test set. They are stored as a triple of (head, tail, rel).
+- udd_[h|r|t], a user-defined dataset. In this format, users should provide the id mappings for entities and relations. The order of the head, relation and tail entities is given by [h|r|t]; for example, udd_trh means the triples are stored in the order of tail, relation, head. It should contain five files:
+  - *entities* stores the mapping between entity names and entity Ids
+  - *relations* stores the mapping between relation names and relation Ids
+  - *train* stores the triples in the training set, each a triple such as [src_id, rel_id, dst_id] in the order specified by [h|r|t]
+  - *valid* stores the triples in the validation set, each a triple such as [src_id, rel_id, dst_id] in the order specified by [h|r|t]
+  - *test* stores the triples in the test set, each a triple such as [src_id, rel_id, dst_id] in the order specified by [h|r|t]
### Output formats:
......
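The id-generation step that the raw_udd_* format delegates to the dataloader can be sketched as follows. This is illustrative only: `build_id_maps` and the sample triples are hypothetical and not part of DGL-KE; it merely shows how name triples in a given [h|r|t] order could be mapped to integer ids, as entities.tsv/relations.tsv would record.

```python
def build_id_maps(triples, order='hrt'):
    # Hypothetical sketch, not DGL-KE code: assign integer ids to entity
    # and relation names on first sight, honoring the [h|r|t] field order.
    h_pos, r_pos, t_pos = order.index('h'), order.index('r'), order.index('t')
    entities, relations, mapped = {}, {}, []
    for trip in triples:
        head, rel, tail = trip[h_pos], trip[r_pos], trip[t_pos]
        entities.setdefault(head, len(entities))
        entities.setdefault(tail, len(entities))
        relations.setdefault(rel, len(relations))
        mapped.append((entities[head], relations[rel], entities[tail]))
    return entities, relations, mapped

# Example raw triples stored in head, relation, tail order:
raw = [('Paris', 'capital_of', 'France'),
       ('Berlin', 'capital_of', 'Germany')]
ents, rels, ids = build_id_maps(raw, order='hrt')
print(ids)  # [(0, 0, 1), (2, 0, 3)]
```

The same triples stored as raw_udd_trh would simply be read with `order='trh'`; the resulting id mappings are identical.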
@@ -29,8 +29,11 @@ class ArgParser(argparse.ArgumentParser):
                           help='root path of all dataset')
         self.add_argument('--dataset', type=str, default='FB15k',
                           help='dataset name, under data_path')
-        self.add_argument('--format', type=str, default='1',
-                          help='the format of the dataset.')
+        self.add_argument('--format', type=str, default='built_in',
+                          help='the format of the dataset, it can be built_in,'\
+                               'raw_udd_{htr} and udd_{htr}')
+        self.add_argument('--data_files', type=str, default=None, nargs='+',
+                          help='a list of data files, e.g. entity relation train valid test')
         self.add_argument('--model_path', type=str, default='ckpts',
                           help='the place where models are saved')
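The two new arguments can be exercised in isolation like this (a minimal sketch using plain `argparse`, with the surrounding `ArgParser` class omitted; the file names passed to `--data_files` are illustrative):

```python
import argparse

# Rebuild just the two arguments added by this commit.
parser = argparse.ArgumentParser()
parser.add_argument('--format', type=str, default='built_in',
                    help='the format of the dataset, it can be built_in, '
                         'raw_udd_{htr} and udd_{htr}')
parser.add_argument('--data_files', type=str, default=None, nargs='+',
                    help='a list of data files, e.g. entity relation train valid test')

# Example: a raw user-defined dataset with triples in head/relation/tail order.
args = parser.parse_args(['--format', 'raw_udd_hrt',
                          '--data_files', 'train.tsv', 'valid.tsv', 'test.tsv'])
print(args.format)      # raw_udd_hrt
print(args.data_files)  # ['train.tsv', 'valid.tsv', 'test.tsv']
```

Because `--data_files` uses `nargs='+'`, it collects one or more file names into a list, while the default of `None` preserves the old built-in behavior when the flag is omitted.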
@@ -97,7 +100,7 @@ def main(args):
     assert not args.eval_filter, "if negative sampling based on degree, we can't filter positive edges."
     # load dataset and samplers
-    dataset = get_dataset(args.data_path, args.dataset, args.format)
+    dataset = get_dataset(args.data_path, args.dataset, args.format, args.data_files)
     args.pickle_graph = False
     args.train = False
     args.valid = False
......
@@ -31,8 +31,11 @@ class ArgParser(argparse.ArgumentParser):
                           help='root path of all dataset')
         self.add_argument('--dataset', type=str, default='FB15k',
                           help='dataset name, under data_path')
-        self.add_argument('--format', type=str, default='1',
-                          help='the format of the dataset.')
+        self.add_argument('--format', type=str, default='built_in',
+                          help='the format of the dataset, it can be built_in,'\
+                               'raw_udd_{htr} and udd_{htr}')
+        self.add_argument('--data_files', type=str, default=None, nargs='+',
+                          help='a list of data files, e.g. entity relation train valid test')
         self.add_argument('--save_path', type=str, default='ckpts',
                           help='place to save models and logs')
         self.add_argument('--save_emb', type=str, default=None,
@@ -156,7 +159,7 @@ def get_logger(args):
 def run(args, logger):
     train_time_start = time.time()
     # load dataset and samplers
-    dataset = get_dataset(args.data_path, args.dataset, args.format)
+    dataset = get_dataset(args.data_path, args.dataset, args.format, args.data_files)
     n_entities = dataset.n_entities
     n_relations = dataset.n_relations
     if args.neg_sample_size_test < 0:
......
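Both call sites now thread `args.data_files` through to `get_dataset`. A hypothetical sketch of how such a dispatcher could route on the format string (this is not DGL-KE's actual implementation, which lives in its dataloader module; the return values here are stand-ins for dataset objects):

```python
def get_dataset(data_path, data_name, format_str, files=None):
    # Hypothetical dispatcher mirroring the call sites in the diff:
    # the format string selects built-in vs. user-defined loading,
    # and 'files' is only consulted for the udd formats.
    if format_str == 'built_in':
        # e.g. FB15k under data_path; 'files' is ignored
        return {'kind': 'built_in', 'name': data_name}
    elif format_str.startswith('raw_udd_'):
        # raw name triples; the loader generates the id mappings
        return {'kind': 'raw_udd', 'order': format_str[-3:], 'files': files}
    elif format_str.startswith('udd_'):
        # user supplies entity/relation id mappings plus the triples
        return {'kind': 'udd', 'order': format_str[-3:], 'files': files}
    raise ValueError('unknown dataset format: ' + format_str)

d = get_dataset('./data', 'mygraph', 'raw_udd_hrt',
                ['train.tsv', 'valid.tsv', 'test.tsv'])
print(d['kind'], d['order'])  # raw_udd hrt
```

Keeping `files=None` as the default means existing built-in invocations continue to work unchanged.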
@@ -189,7 +189,6 @@ def test(args, model, test_samplers, rank=0, mode='Test', queue=None):
     with th.no_grad():
         logs = []
         for sampler in test_samplers:
-            count = 0
             for pos_g, neg_g in sampler:
                 model.forward_test(pos_g, neg_g, logs, gpu_id)
......