"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "3ecbbd628825712552a1834f78a224bfab4cfe21"
Unverified commit f9ba3cdc authored by xiang song (charlie.song), committed by GitHub

[KGE] Kg loader (#1300)



* Now dataset accepts user-defined datasets

* Update README

* Fix eval

* Fix

* Fix Freebase

* Fix

* Fix

* upd

* upd

* Update README

* Update some docstrings.

* upd
Co-authored-by: Ubuntu <ubuntu@ip-172-31-8-26.us-east-2.compute.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-19-78.us-east-2.compute.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-60-78.ec2.internal>
parent 614acf2c
```diff
@@ -42,9 +42,9 @@ We will support multi-GPU training and distributed training in a near future.
 The package can run with both Pytorch and MXNet. For Pytorch, it works with Pytorch v1.2 or newer.
 For MXNet, it works with MXNet 1.5 or newer.
 
-## Datasets
-DGL-KE provides five knowledge graphs:
+## Built-in Datasets
+DGL-KE provides five built-in knowledge graphs:
 
 | Dataset | #nodes | #edges | #relations |
 |---------|--------|--------|------------|
```
```diff
@@ -129,24 +129,20 @@ when given (?, rel, tail).
 ### Input formats:
 
-DGL-KE supports two knowledge graph input formats. A knowledge graph is stored
-using five files.
-
-Format 1:
-
-- entities.dict contains pairs of (entity Id, entity name). The number of rows is the number of entities (nodes).
-- relations.dict contains pairs of (relation Id, relation name). The number of rows is the number of relations.
-- train.txt stores edges in the training set. They are stored as triples of (head, rel, tail).
-- valid.txt stores edges in the validation set. They are stored as triples of (head, rel, tail).
-- test.txt stores edges in the test set. They are stored as triples of (head, rel, tail).
+DGL-KE supports two knowledge graph input formats for user-defined datasets.
+
+Format 1:
+
+- raw_udd_[h|r|t], raw user-defined dataset. In this format, the user only needs to provide the triples; the dataloader generates and manages the id mappings itself, writing two files: entities.tsv for the entity id mapping and relations.tsv for the relation id mapping. The order of head, relation and tail entities is given by [h|r|t]; for example, raw_udd_trh means the triples are stored in the order tail, relation, head. It should contain three files:
+  - *train* stores the triples in the training set, each a triple such as [src_name, rel_name, dst_name], in the order specified by [h|r|t].
+  - *valid* stores the triples in the validation set, each a triple such as [src_name, rel_name, dst_name], in the order specified by [h|r|t].
+  - *test* stores the triples in the test set, each a triple such as [src_name, rel_name, dst_name], in the order specified by [h|r|t].
 
 Format 2:
 
-- entity2id.txt contains pairs of (entity name, entity Id). The number of rows is the number of entities (nodes).
-- relation2id.txt contains pairs of (relation name, relation Id). The number of rows is the number of relations.
-- train.txt stores edges in the training set. They are stored as triples of (head, tail, rel).
-- valid.txt stores edges in the validation set. They are stored as a triple of (head, tail, rel).
-- test.txt stores edges in the test set. They are stored as a triple of (head, tail, rel).
+- udd_[h|r|t], user-defined dataset. In this format, the user provides the id mappings for entities and relations. The order of head, relation and tail entities is given by [h|r|t]; for example, udd_trh means the triples are stored in the order tail, relation, head. It should contain five files:
+  - *entities* stores the mapping between entity name and entity Id.
+  - *relations* stores the mapping between relation name and relation Id.
+  - *train* stores the triples in the training set, each a triple such as [src_id, rel_id, dst_id], in the order specified by [h|r|t].
+  - *valid* stores the triples in the validation set, each a triple such as [src_id, rel_id, dst_id], in the order specified by [h|r|t].
+  - *test* stores the triples in the test set, each a triple such as [src_id, rel_id, dst_id], in the order specified by [h|r|t].
 
 ### Output formats:
```
......
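To make the raw_udd layout concrete, here is a small Python sketch that writes a toy dataset in raw_udd_hrt order. The directory name my_kg, the triples, and the tab-separated layout are illustrative assumptions; per the description above, entities.tsv and relations.tsv would then be produced by the dataloader, not by the user.

```python
import os

# Toy triples in (head, relation, tail) order, i.e. raw_udd_hrt.
# All names are illustrative; ids are assigned later by the dataloader.
triples = {
    'train': [('Paris', 'capital_of', 'France'),
              ('Berlin', 'capital_of', 'Germany')],
    'valid': [('Madrid', 'capital_of', 'Spain')],
    'test':  [('Rome', 'capital_of', 'Italy')],
}

os.makedirs('my_kg', exist_ok=True)
for split, rows in triples.items():
    # one triple per line, file named after the split
    with open(os.path.join('my_kg', split), 'w') as f:
        for h, r, t in rows:
            f.write(f'{h}\t{r}\t{t}\n')
```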
(A large file diff is collapsed here and not shown.)
```diff
@@ -29,8 +29,11 @@ class ArgParser(argparse.ArgumentParser):
                           help='root path of all dataset')
         self.add_argument('--dataset', type=str, default='FB15k',
                           help='dataset name, under data_path')
-        self.add_argument('--format', type=str, default='1',
-                          help='the format of the dataset.')
+        self.add_argument('--format', type=str, default='built_in',
+                          help='the format of the dataset, it can be built_in,'\
+                               'raw_udd_{htr} and udd_{htr}')
+        self.add_argument('--data_files', type=str, default=None, nargs='+',
+                          help='a list of data files, e.g. entity relation train valid test')
         self.add_argument('--model_path', type=str, default='ckpts',
                           help='the place where models are saved')
```
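To see what the two new flags accept, here is a minimal standalone argparse sketch that mirrors just the added arguments; the command-line values are illustrative, not taken from the diff.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--format', type=str, default='built_in',
                    help='the format of the dataset, it can be built_in, '
                         'raw_udd_{htr} and udd_{htr}')
parser.add_argument('--data_files', type=str, default=None, nargs='+',
                    help='a list of data files, e.g. entity relation train valid test')

# nargs='+' gathers every following token into a single list, so one flag
# can carry all five udd files:
args = parser.parse_args(['--format', 'udd_hrt', '--data_files',
                          'entities.dict', 'relations.dict',
                          'train.txt', 'valid.txt', 'test.txt'])
assert args.format == 'udd_hrt'
assert args.data_files == ['entities.dict', 'relations.dict',
                           'train.txt', 'valid.txt', 'test.txt']
```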
```diff
@@ -97,7 +100,7 @@ def main(args):
     assert not args.eval_filter, "if negative sampling based on degree, we can't filter positive edges."
 
     # load dataset and samplers
-    dataset = get_dataset(args.data_path, args.dataset, args.format)
+    dataset = get_dataset(args.data_path, args.dataset, args.format, args.data_files)
     args.pickle_graph = False
     args.train = False
     args.valid = False
```
......
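The hunk above only shows the call site, so the following is a hedged sketch of how a get_dataset dispatcher might route on the extended signature. The branch bodies are placeholders, not the real DGL-KE loader code.

```python
def get_dataset(data_path, dataset, fmt, files=None):
    # Illustrative routing only; each branch stands in for the real loader.
    if fmt == 'built_in':
        raise NotImplementedError('load one of the five shipped datasets')
    elif fmt.startswith('raw_udd_'):
        raise NotImplementedError('read raw triples from files; '
                                  'generate entities.tsv / relations.tsv')
    elif fmt.startswith('udd_'):
        raise NotImplementedError('read the user-supplied id mappings '
                                  'and triples from files')
    else:
        raise ValueError('unknown dataset format: ' + fmt)
```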
```diff
@@ -31,8 +31,11 @@ class ArgParser(argparse.ArgumentParser):
                           help='root path of all dataset')
         self.add_argument('--dataset', type=str, default='FB15k',
                           help='dataset name, under data_path')
-        self.add_argument('--format', type=str, default='1',
-                          help='the format of the dataset.')
+        self.add_argument('--format', type=str, default='built_in',
+                          help='the format of the dataset, it can be built_in,'\
+                               'raw_udd_{htr} and udd_{htr}')
+        self.add_argument('--data_files', type=str, default=None, nargs='+',
+                          help='a list of data files, e.g. entity relation train valid test')
         self.add_argument('--save_path', type=str, default='ckpts',
                           help='place to save models and logs')
         self.add_argument('--save_emb', type=str, default=None,
```
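Both parsers take the same {htr} suffix, and a loader has to turn it into column positions in the triple files. A short sketch of that mapping follows; column_order is a hypothetical helper for illustration, not part of the diff.

```python
def column_order(fmt):
    """Map a format string such as 'raw_udd_trh' or 'udd_hrt' to the
    column indices of (head, relation, tail) in each triple file."""
    suffix = fmt.rsplit('_', 1)[-1]  # 'raw_udd_trh' -> 'trh'
    assert sorted(suffix) == ['h', 'r', 't'], 'need a permutation of h, r, t'
    return suffix.index('h'), suffix.index('r'), suffix.index('t')

# In raw_udd_trh the head is the third column, relation the second,
# tail the first:
assert column_order('raw_udd_trh') == (2, 1, 0)
assert column_order('udd_hrt') == (0, 1, 2)
```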
```diff
@@ -156,7 +159,7 @@ def get_logger(args):
 def run(args, logger):
     train_time_start = time.time()
 
     # load dataset and samplers
-    dataset = get_dataset(args.data_path, args.dataset, args.format)
+    dataset = get_dataset(args.data_path, args.dataset, args.format, args.data_files)
     n_entities = dataset.n_entities
     n_relations = dataset.n_relations
     if args.neg_sample_size_test < 0:
```
......
```diff
@@ -189,7 +189,6 @@ def test(args, model, test_samplers, rank=0, mode='Test', queue=None):
     with th.no_grad():
         logs = []
         for sampler in test_samplers:
-            count = 0
             for pos_g, neg_g in sampler:
                 model.forward_test(pos_g, neg_g, logs, gpu_id)
```
......