Commit 1c91f460 authored by Minjie Wang, committed by GitHub

[Hetero][Model] RGCN for heterogeneous input (#885)

* new hetero RGCN

* bgs running

* fix gpu

* am dataset

* fix bug in label preparation

* Fix AM training; add result

* rm sym link

* new embed layer; mutag

* mutag matched; other fix

* minor fix

* dataset refactor

* new data loading

* rm old files

* refactor

* docstring

* include literal nodes in AIFB dataset

* address comments

* docstring
parent 65e1ba4f
# Relational-GCN
* Paper: [https://arxiv.org/abs/1703.06103](https://arxiv.org/abs/1703.06103)
* Author's code for entity classification: [https://github.com/tkipf/relational-gcn](https://github.com/tkipf/relational-gcn)
* Author's code for link prediction: [https://github.com/MichSchli/RelationPrediction](https://github.com/MichSchli/RelationPrediction)
The preprocessing is slightly different from the author's code. We directly load and preprocess
the raw RDF data. For AIFB, BGS and AM,
all literal nodes are pruned from the graph. For AIFB, some training/testing nodes
thus become orphans and are excluded from the training/testing set. The resulting graph
has fewer entities and relations. As a reference (numbers include reverse edges and relations):
| Dataset | #Nodes | #Edges | #Relations | #Labeled |
| --- | --- | --- | --- | --- |
| AIFB | 8,285 | 58,086 | 90 | 176 |
| AIFB-hetero | 7,262 | 48,810 | 78 | 176 |
| MUTAG | 23,644 | 148,454 | 46 | 340 |
| MUTAG-hetero | 27,163 | 148,100 | 46 | 340 |
| BGS | 333,845 | 1,832,398 | 206 | 146 |
| BGS-hetero | 94,806 | 672,884 | 96 | 146 |
| AM | 1,666,764 | 11,976,642 | 266 | 1000 |
| AM-hetero | 881,680 | 5,668,682 | 96 | 1000 |
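The processed graphs are exposed through the `dgl.data.rdf` module added in this commit. Below is a minimal sketch of loading one dataset and inspecting the resulting heterograph (attribute names follow `entity_classify.py`; it assumes the module is importable from your DGL installation):
```
from dgl.data.rdf import AIFB   # MUTAG, BGS and AM work the same way

dataset = AIFB()                # downloads, parses and caches the RDF graph
g = dataset.graph               # a DGLHeteroGraph
print(g.ntypes)                 # entity classes (node types)
print(len(set(g.etypes)))       # unique relation names, incl. reverse relations
print(dataset.predict_category, dataset.num_classes)
print(len(dataset.train_idx), len(dataset.test_idx))
```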
### Dependencies
* PyTorch 1.0+
* requests
* rdflib
```
pip install requests torch rdflib pandas
```
The example code was tested with rdflib 4.2.2 and pandas 0.23.4.
### Entity Classification
(all experiments use one-hot encoding as featureless input)
AIFB: accuracy 97.22% (DGL), 95.83% (paper)
```
python3 entity_classify.py -d aifb --testing --gpu 0
```
MUTAG: accuracy 73.53% (DGL), 73.23% (paper)
```
python3 entity_classify.py -d mutag --l2norm 5e-4 --n-bases 30 --testing --gpu 0
```
BGS: accuracy 93.10% (DGL), 83.10% (paper)
```
python3 entity_classify.py -d bgs --l2norm 5e-4 --n-bases 40 --testing --gpu 0
```
AM: accuracy 91.41% (DGL), 89.29% (paper)
```
python3 entity_classify.py -d am --l2norm 5e-4 --n-bases 40 --testing --gpu 0
```
"""Modeling Relational Data with Graph Convolutional Networks
Paper: https://arxiv.org/abs/1703.06103
Reference Code: https://github.com/tkipf/relational-gcn
"""
import argparse
import numpy as np
import time
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
import dgl.function as fn
from dgl.data.rdf import AIFB, MUTAG, BGS, AM
class RelGraphConvHetero(nn.Module):
r"""Relational graph convolution layer.
Parameters
----------
in_feat : int
Input feature size.
out_feat : int
Output feature size.
rel_names : list of str
Relation names.
regularizer : str
Which weight regularizer to use: "basis" or "bdd".
num_bases : int, optional
Number of bases. If None, use the number of relations. Default: None.
bias : bool, optional
True if bias is added. Default: True
activation : callable, optional
Activation function. Default: None
self_loop : bool, optional
True to include self loop message. Default: False
dropout : float, optional
Dropout rate. Default: 0.0
"""
def __init__(self,
in_feat,
out_feat,
rel_names,
regularizer="basis",
num_bases=None,
bias=True,
activation=None,
self_loop=False,
dropout=0.0):
super(RelGraphConvHetero, self).__init__()
self.in_feat = in_feat
self.out_feat = out_feat
self.rel_names = rel_names
self.num_rels = len(rel_names)
self.regularizer = regularizer
self.num_bases = num_bases
if self.num_bases is None or self.num_bases > self.num_rels or self.num_bases < 0:
self.num_bases = self.num_rels
self.bias = bias
self.activation = activation
self.self_loop = self_loop
if regularizer == "basis":
# add basis weights
self.weight = nn.Parameter(th.Tensor(self.num_bases, self.in_feat, self.out_feat))
if self.num_bases < self.num_rels:
# linear combination coefficients
self.w_comp = nn.Parameter(th.Tensor(self.num_rels, self.num_bases))
nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
if self.num_bases < self.num_rels:
nn.init.xavier_uniform_(self.w_comp,
gain=nn.init.calculate_gain('relu'))
else:
raise ValueError("Only basis regularizer is supported.")
# bias
if self.bias:
self.h_bias = nn.Parameter(th.Tensor(out_feat))
nn.init.zeros_(self.h_bias)
# weight for self loop
if self.self_loop:
self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
nn.init.xavier_uniform_(self.loop_weight,
gain=nn.init.calculate_gain('relu'))
self.dropout = nn.Dropout(dropout)
def basis_weight(self):
"""Message function for basis regularizer"""
if self.num_bases < self.num_rels:
# generate all weights from bases
weight = self.weight.view(self.num_bases,
self.in_feat * self.out_feat)
weight = th.matmul(self.w_comp, weight).view(
self.num_rels, self.in_feat, self.out_feat)
else:
weight = self.weight
return {self.rel_names[i] : w.squeeze(0) for i, w in enumerate(th.split(weight, 1, dim=0))}
def forward(self, g, xs):
""" Forward computation
Parameters
----------
g : DGLHeteroGraph
Input graph.
xs : list of torch.Tensor
Node feature for each node type.
Returns
-------
list of torch.Tensor
New node features for each node type.
"""
g = g.local_var()
for i, ntype in enumerate(g.ntypes):
g.nodes[ntype].data['x'] = xs[i]
ws = self.basis_weight()
funcs = {}
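# For each canonical edge type, project the source features with that relation's
# weight and average the messages over incoming edges of that relation;
# multi_update_all below then sums the per-relation results on every destination node.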
for i, (srctype, etype, dsttype) in enumerate(g.canonical_etypes):
g.nodes[srctype].data['h%d' % i] = th.matmul(
g.nodes[srctype].data['x'], ws[etype])
funcs[(srctype, etype, dsttype)] = (fn.copy_u('h%d' % i, 'm'), fn.mean('m', 'h'))
# message passing
g.multi_update_all(funcs, 'sum')
hs = [g.nodes[ntype].data['h'] for ntype in g.ntypes]
for i in range(len(hs)):
h = hs[i]
# apply bias and activation
if self.self_loop:
h = h + th.matmul(xs[i], self.loop_weight)
if self.bias:
h = h + self.h_bias
if self.activation:
h = self.activation(h)
h = self.dropout(h)
hs[i] = h
return hs
class RelGraphConvHeteroEmbed(nn.Module):
r"""Embedding layer for featureless heterograph."""
def __init__(self,
embed_size,
g,
bias=True,
activation=None,
self_loop=False,
dropout=0.0):
super(RelGraphConvHeteroEmbed, self).__init__()
self.embed_size = embed_size
self.g = g
self.bias = bias
self.activation = activation
self.self_loop = self_loop
# create weight embeddings for each node for each relation
self.embeds = nn.ParameterList()
for srctype, etype, dsttype in g.canonical_etypes:
embed = nn.Parameter(th.Tensor(g.number_of_nodes(srctype), self.embed_size))
nn.init.xavier_uniform_(embed, gain=nn.init.calculate_gain('relu'))
self.embeds.append(embed)
# bias
if self.bias:
self.h_bias = nn.Parameter(th.Tensor(embed_size))
nn.init.zeros_(self.h_bias)
# weight for self loop
if self.self_loop:
self.self_embeds = nn.ParameterList()
for ntype in g.ntypes:
embed = nn.Parameter(th.Tensor(g.number_of_nodes(ntype), embed_size))
nn.init.xavier_uniform_(embed,
gain=nn.init.calculate_gain('relu'))
self.self_embeds.append(embed)
self.dropout = nn.Dropout(dropout)
def forward(self):
""" Forward computation
Returns
-------
torch.Tensor
New node features.
"""
g = self.g.local_var()
funcs = {}
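# Each relation contributes the mean of its own learnable source-node embeddings;
# the per-relation results are then summed for every destination node type.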
for i, (srctype, etype, dsttype) in enumerate(g.canonical_etypes):
g.nodes[srctype].data['embed-%d' % i] = self.embeds[i]
funcs[(srctype, etype, dsttype)] = (fn.copy_u('embed-%d' % i, 'm'), fn.mean('m', 'h'))
g.multi_update_all(funcs, 'sum')
hs = [g.nodes[ntype].data['h'] for ntype in g.ntypes]
for i in range(len(hs)):
h = hs[i]
# apply bias and activation
if self.self_loop:
h = h + self.self_embeds[i]
if self.bias:
h = h + self.h_bias
if self.activation:
h = self.activation(h)
h = self.dropout(h)
hs[i] = h
return hs
class EntityClassify(nn.Module):
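"""Entity classification model.
An embedding layer (RelGraphConvHeteroEmbed) provides features for the featureless
input graph, followed by ``num_hidden_layers`` RelGraphConvHetero layers with ReLU
and a final RelGraphConvHetero layer with a softmax output over the target classes.
"""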
def __init__(self,
g,
h_dim, out_dim,
num_bases,
num_hidden_layers=1,
dropout=0,
use_self_loop=False):
super(EntityClassify, self).__init__()
self.g = g
self.h_dim = h_dim
self.out_dim = out_dim
self.rel_names = list(set(g.etypes))
self.num_bases = None if num_bases < 0 else num_bases
self.num_hidden_layers = num_hidden_layers
self.dropout = dropout
self.use_self_loop = use_self_loop
self.embed_layer = RelGraphConvHeteroEmbed(
self.h_dim, g, activation=F.relu, self_loop=self.use_self_loop,
dropout=self.dropout)
self.layers = nn.ModuleList()
# h2h
for i in range(self.num_hidden_layers):
self.layers.append(RelGraphConvHetero(
self.h_dim, self.h_dim, self.rel_names, "basis",
self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
dropout=self.dropout))
# h2o
self.layers.append(RelGraphConvHetero(
self.h_dim, self.out_dim, self.rel_names, "basis",
self.num_bases, activation=partial(F.softmax, dim=1),
self_loop=self.use_self_loop))
def forward(self):
h = self.embed_layer()
for layer in self.layers:
h = layer(self.g, h)
return h
def main(args):
# load graph data
if args.dataset == 'aifb':
dataset = AIFB()
elif args.dataset == 'mutag':
dataset = MUTAG()
elif args.dataset == 'bgs':
dataset = BGS()
elif args.dataset == 'am':
dataset = AM()
else:
raise ValueError('Unknown dataset: {}'.format(args.dataset))
g = dataset.graph
category = dataset.predict_category
num_classes = dataset.num_classes
train_idx = dataset.train_idx
test_idx = dataset.test_idx
labels = dataset.labels
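# The model returns one logits tensor per node type (in the order of g.ntypes);
# locate the index of the category that carries the labels.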
category_id = len(g.ntypes)
for i, ntype in enumerate(g.ntypes):
if ntype == category:
category_id = i
# split dataset into train, validate, test
if args.validation:
val_idx = train_idx[:len(train_idx) // 5]
train_idx = train_idx[len(train_idx) // 5:]
else:
val_idx = train_idx
# check cuda
use_cuda = args.gpu >= 0 and th.cuda.is_available()
if use_cuda:
th.cuda.set_device(args.gpu)
labels = labels.cuda()
train_idx = train_idx.cuda()
test_idx = test_idx.cuda()
# create model
model = EntityClassify(g,
args.n_hidden,
num_classes,
num_bases=args.n_bases,
num_hidden_layers=args.n_layers - 2,
dropout=args.dropout,
use_self_loop=args.use_self_loop)
if use_cuda:
model.cuda()
# optimizer
optimizer = th.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2norm)
# training loop
print("start training...")
dur = []
model.train()
for epoch in range(args.n_epochs):
optimizer.zero_grad()
if epoch > 5:
t0 = time.time()
logits = model()[category_id]
loss = F.cross_entropy(logits[train_idx], labels[train_idx])
loss.backward()
optimizer.step()
t1 = time.time()
if epoch > 5:
dur.append(t1 - t0)
train_acc = th.sum(logits[train_idx].argmax(dim=1) == labels[train_idx]).item() / len(train_idx)
val_loss = F.cross_entropy(logits[val_idx], labels[val_idx])
val_acc = th.sum(logits[val_idx].argmax(dim=1) == labels[val_idx]).item() / len(val_idx)
print("Epoch {:05d} | Train Acc: {:.4f} | Train Loss: {:.4f} | Valid Acc: {:.4f} | Valid loss: {:.4f} | Time: {:.4f}".
format(epoch, train_acc, loss.item(), val_acc, val_loss.item(), np.average(dur)))
print()
model.eval()
logits = model.forward()[category_id]
test_loss = F.cross_entropy(logits[test_idx], labels[test_idx])
test_acc = th.sum(logits[test_idx].argmax(dim=1) == labels[test_idx]).item() / len(test_idx)
print("Test Acc: {:.4f} | Test loss: {:.4f}".format(test_acc, test_loss.item()))
print()
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='RGCN')
parser.add_argument("--dropout", type=float, default=0,
help="dropout probability")
parser.add_argument("--n-hidden", type=int, default=16,
help="number of hidden units")
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
parser.add_argument("--lr", type=float, default=1e-2,
help="learning rate")
parser.add_argument("--n-bases", type=int, default=-1,
help="number of filter weight matrices, default: -1 [use all]")
parser.add_argument("--n-layers", type=int, default=2,
help="number of propagation rounds")
parser.add_argument("-e", "--n-epochs", type=int, default=50,
help="number of training epochs")
parser.add_argument("-d", "--dataset", type=str, required=True,
help="dataset to use")
parser.add_argument("--l2norm", type=float, default=0,
help="l2 norm coef")
parser.add_argument("--use-self-loop", default=False, action='store_true',
help="include self feature as a special relation")
fp = parser.add_mutually_exclusive_group(required=False)
fp.add_argument('--validation', dest='validation', action='store_true')
fp.add_argument('--testing', dest='validation', action='store_false')
parser.set_defaults(validation=True)
args = parser.parse_args()
print(args)
main(args)
@@ -32,6 +32,11 @@ BGS: accuracy 82.76% (DGL), 83.10% (paper)
 python3 entity_classify.py -d bgs --l2norm 5e-4 --n-bases 40 --testing --gpu 0 --relabel
 ```
+AM: accuracy 87.37% (DGL), 89.29% (paper)
+```
+python3 entity_classify.py -d am --n-bases=40 --n-hidden=10 --l2norm=5e-4 --testing
+```
 ### Link Prediction
 FB15k-237: MRR 0.151 (DGL), 0.158 (paper)
 ```
...
@@ -327,7 +327,7 @@ def _load_data(dataset_str='aifb', dataset_path=None):
 train_file = os.path.join(dataset_path, 'trainingSet.tsv')
 test_file = os.path.join(dataset_path, 'testSet.tsv')
 if dataset_str == 'am':
-    label_header = 'label_category'
+    label_header = 'label_cateogory'
 nodes_header = 'proxy'
 elif dataset_str == 'aifb':
 label_header = 'label_affiliation'
...
"""RDF datasets
Datasets from "A Collection of Benchmark Datasets for
Systematic Evaluations of Machine Learning on
the Semantic Web"
"""
import os
from collections import OrderedDict
import itertools
import rdflib as rdf
import abc
import re
import networkx as nx
import numpy as np
import dgl
import dgl.backend as F
from .utils import download, extract_archive, get_download_dir, _get_dgl_url
__all__ = ['AIFB', 'MUTAG', 'BGS', 'AM']
class Entity:
"""Class for entities
Parameters
----------
id : str
ID of this entity
cls : str
Type of this entity
"""
def __init__(self, id, cls):
self.id = id
self.cls = cls
def __str__(self):
return '{}/{}'.format(self.cls, self.id)
class Relation:
"""Class for relations
Parameters
----------
cls : str
Type of this relation
"""
def __init__(self, cls):
self.cls = cls
def __str__(self):
return str(self.cls)
class RDFGraphDataset:
"""Base graph dataset class from RDF tuples.
To derive from this, implement the following abstract methods:
* ``parse_entity``
* ``parse_relation``
* ``process_tuple``
* ``process_idx_file_line``
* ``predict_category``
Preprocessed graph and other data will be cached in the download folder
to speedup data loading.
The dataset should contain a "trainingSet.tsv" and a "testSet.tsv" file
for training and testing samples.
Attributes
----------
graph : dgl.DGLHeteroGraph
Graph structure
num_classes : int
Number of classes to predict
predict_category : str
The entity category (node type) that has labels for prediction
train_idx : Tensor
Entity IDs for training. All IDs are local IDs w.r.t. ``predict_category``.
test_idx : Tensor
Entity IDs for testing. All IDs are local IDs w.r.t. ``predict_category``.
labels : Tensor
All the labels of the entities in ``predict_category``
Parameters
----------
url : str or path
URL to download the raw dataset.
name : str
Name of the dataset
force_reload : bool, optional
If true, force load and process from raw data. Ignore cached pre-processed data.
print_every : int, optional
Print a progress message every ``print_every`` processed tuples.
insert_reverse : bool, optional
If true, add reverse edges and reverse relations to the final graph.
"""
def __init__(self, url, name,
force_reload=False,
print_every=10000,
insert_reverse=True):
download_dir = get_download_dir()
zip_file_path = os.path.join(download_dir, '{}.zip'.format(name))
download(url, path=zip_file_path)
self._dir = os.path.join(download_dir, name)
extract_archive(zip_file_path, self._dir)
self._print_every = print_every
self._insert_reverse = insert_reverse
if not force_reload and self.has_cache():
print('Found cached graph. Load cache ...')
self.load_cache()
else:
raw_tuples = self.load_raw_tuples()
self.process_raw_tuples(raw_tuples)
print('#Training samples:', len(self.train_idx))
print('#Testing samples:', len(self.test_idx))
print('#Classes:', self.num_classes)
print('Predict category:', self.predict_category)
def load_raw_tuples(self):
raw_rdf_graphs = []
for i, filename in enumerate(os.listdir(self._dir)):
fmt = None
if filename.endswith('nt'):
fmt = 'nt'
elif filename.endswith('n3'):
fmt = 'n3'
if fmt is None:
continue
g = rdf.Graph()
print('Parsing file %s ...' % filename)
g.parse(os.path.join(self._dir, filename), format=fmt)
raw_rdf_graphs.append(g)
return itertools.chain(*raw_rdf_graphs)
def process_raw_tuples(self, raw_tuples):
mg = nx.MultiDiGraph()
ent_classes = OrderedDict()
rel_classes = OrderedDict()
entities = OrderedDict()
src = []
dst = []
ntid = []
etid = []
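# First pass: build a flat edge list (src, dst) together with per-node type ids
# (ntid), per-edge type ids (etid) and a networkx metagraph; build_graph() later
# converts this homogeneous graph into a DGLHeteroGraph.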
for i, (sbj, pred, obj) in enumerate(raw_tuples):
if i % self._print_every == 0:
print('Processed %d tuples, found %d valid tuples.' % (i, len(src)))
sbjent = self.parse_entity(sbj)
rel = self.parse_relation(pred)
objent = self.parse_entity(obj)
processed = self.process_tuple((sbj, pred, obj), sbjent, rel, objent)
if processed is None:
# ignored
continue
# meta graph
sbjclsid = _get_id(ent_classes, sbjent.cls)
objclsid = _get_id(ent_classes, objent.cls)
relclsid = _get_id(rel_classes, rel.cls)
mg.add_edge(sbjent.cls, objent.cls, key=rel.cls)
if self._insert_reverse:
mg.add_edge(objent.cls, sbjent.cls, key='rev-%s' % rel.cls)
# instance graph
src_id = _get_id(entities, str(sbjent))
if len(entities) > len(ntid): # found new entity
ntid.append(sbjclsid)
dst_id = _get_id(entities, str(objent))
if len(entities) > len(ntid): # found new entity
ntid.append(objclsid)
src.append(src_id)
dst.append(dst_id)
etid.append(relclsid)
src = np.array(src)
dst = np.array(dst)
ntid = np.array(ntid)
etid = np.array(etid)
ntypes = list(ent_classes.keys())
etypes = list(rel_classes.keys())
# add reverse edge with reverse relation
if self._insert_reverse:
print('Adding reverse edges ...')
newsrc = np.hstack([src, dst])
newdst = np.hstack([dst, src])
src = newsrc
dst = newdst
etid = np.hstack([etid, etid + len(etypes)])
etypes.extend(['rev-%s' % t for t in etypes])
self.build_graph(mg, src, dst, ntid, etid, ntypes, etypes)
print('Load training/validation/testing split ...')
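# dgl.to_hetero stores each node's original (homogeneous) ID in ndata[dgl.NID];
# build a global-to-local ID map for the category we predict on.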
idmap = F.asnumpy(self.graph.nodes[self.predict_category].data[dgl.NID])
glb2lcl = {glbid : lclid for lclid, glbid in enumerate(idmap)}
def findidfn(ent):
if ent not in entities:
return None
else:
return glb2lcl[entities[ent]]
self.load_data_split(findidfn)
self.save_cache(mg, src, dst, ntid, etid, ntypes, etypes)
def build_graph(self, mg, src, dst, ntid, etid, ntypes, etypes):
# create homo graph
print('Creating one whole graph ...')
g = dgl.graph((src, dst))
g.ndata[dgl.NTYPE] = F.tensor(ntid)
g.edata[dgl.ETYPE] = F.tensor(etid)
print('Total #nodes:', g.number_of_nodes())
print('Total #edges:', g.number_of_edges())
# convert to heterograph
print('Convert to heterograph ...')
hg = dgl.to_hetero(g,
ntypes,
etypes,
metagraph=mg)
print('#Node types:', len(hg.ntypes))
print('#Canonical edge types:', len(hg.etypes))
print('#Unique edge type names:', len(set(hg.etypes)))
self.graph = hg
def save_cache(self, mg, src, dst, ntid, etid, ntypes, etypes):
nx.write_gpickle(mg, os.path.join(self._dir, 'cached_mg.gpickle'))
np.save(os.path.join(self._dir, 'cached_src.npy'), src)
np.save(os.path.join(self._dir, 'cached_dst.npy'), dst)
np.save(os.path.join(self._dir, 'cached_ntid.npy'), ntid)
np.save(os.path.join(self._dir, 'cached_etid.npy'), etid)
save_strlist(os.path.join(self._dir, 'cached_ntypes.txt'), ntypes)
save_strlist(os.path.join(self._dir, 'cached_etypes.txt'), etypes)
np.save(os.path.join(self._dir, 'cached_train_idx.npy'), F.asnumpy(self.train_idx))
np.save(os.path.join(self._dir, 'cached_test_idx.npy'), F.asnumpy(self.test_idx))
np.save(os.path.join(self._dir, 'cached_labels.npy'), F.asnumpy(self.labels))
def has_cache(self):
return (os.path.exists(os.path.join(self._dir, 'cached_mg.gpickle'))
and os.path.exists(os.path.join(self._dir, 'cached_src.npy'))
and os.path.exists(os.path.join(self._dir, 'cached_dst.npy'))
and os.path.exists(os.path.join(self._dir, 'cached_ntid.npy'))
and os.path.exists(os.path.join(self._dir, 'cached_etid.npy'))
and os.path.exists(os.path.join(self._dir, 'cached_ntypes.txt'))
and os.path.exists(os.path.join(self._dir, 'cached_etypes.txt'))
and os.path.exists(os.path.join(self._dir, 'cached_train_idx.npy'))
and os.path.exists(os.path.join(self._dir, 'cached_test_idx.npy'))
and os.path.exists(os.path.join(self._dir, 'cached_labels.npy')))
def load_cache(self):
mg = nx.read_gpickle(os.path.join(self._dir, 'cached_mg.gpickle'))
src = np.load(os.path.join(self._dir, 'cached_src.npy'))
dst = np.load(os.path.join(self._dir, 'cached_dst.npy'))
ntid = np.load(os.path.join(self._dir, 'cached_ntid.npy'))
etid = np.load(os.path.join(self._dir, 'cached_etid.npy'))
ntypes = load_strlist(os.path.join(self._dir, 'cached_ntypes.txt'))
etypes = load_strlist(os.path.join(self._dir, 'cached_etypes.txt'))
self.train_idx = F.tensor(np.load(os.path.join(self._dir, 'cached_train_idx.npy')))
self.test_idx = F.tensor(np.load(os.path.join(self._dir, 'cached_test_idx.npy')))
labels = np.load(os.path.join(self._dir, 'cached_labels.npy'))
self.num_classes = labels.max() + 1
self.labels = F.tensor(labels)
self.build_graph(mg, src, dst, ntid, etid, ntypes, etypes)
def load_data_split(self, ent2id):
label_dict = {}
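# Labels default to -1 (unlabeled); parse_idx_file fills in class ids for the
# entities listed in trainingSet.tsv / testSet.tsv.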
labels = np.zeros((self.graph.number_of_nodes(self.predict_category),)) - 1
train_idx = self.parse_idx_file(
os.path.join(self._dir, 'trainingSet.tsv'),
ent2id, label_dict, labels)
test_idx = self.parse_idx_file(
os.path.join(self._dir, 'testSet.tsv'),
ent2id, label_dict, labels)
self.train_idx = F.tensor(train_idx)
self.test_idx = F.tensor(test_idx)
self.labels = F.tensor(labels).long()
self.num_classes = len(label_dict)
def parse_idx_file(self, filename, ent2id, label_dict, labels):
idx = []
with open(filename, 'r') as f:
for i, line in enumerate(f):
if i == 0:
continue # first line is the header
sample, label = self.process_idx_file_line(line)
#person, _, label = line.strip().split('\t')
ent = self.parse_entity(sample)
entid = ent2id(str(ent))
if entid is None:
print('Warning: entity "%s" does not have any valid links associated. Ignored.' % str(ent))
else:
idx.append(entid)
lblid = _get_id(label_dict, label)
labels[entid] = lblid
return idx
@abc.abstractmethod
def parse_entity(self, term):
"""Parse one entity from an RDF term.
Return None if the term does not represent a valid entity and the
whole tuple should be ignored.
Parameters
----------
term : rdflib.term.Identifier
RDF term
Returns
-------
Entity or None
An entity.
"""
pass
@abc.abstractmethod
def parse_relation(self, term):
"""Parse one relation from an RDF term.
Return None if the term does not represent a valid relation and the
whole tuple should be ignored.
Parameters
----------
term : rdflib.term.Identifier
RDF term
Returns
-------
Relation or None
A relation.
"""
pass
@abc.abstractmethod
def process_tuple(self, raw_tuple, sbj, rel, obj):
"""Process the tuple.
Return an (Entity, Relation, Entity) tuple as the final tuple.
Return None if the tuple should be ignored.
Parameters
----------
raw_tuple : tuple of rdflib.term.Identifier
(subject, predicate, object) tuple
sbj : Entity
Subject entity
rel : Relation
Relation
obj : Entity
Object entity
Returns
-------
(Entity, Relation, Entity)
The final tuple, or None if the tuple should be ignored.
"""
pass
@abc.abstractmethod
def process_idx_file_line(self, line):
"""Process one line of ``trainingSet.tsv`` or ``testSet.tsv``.
Parameters
----------
line : str
One line of the file
Returns
-------
(str, str)
One sample and its label
"""
pass
@property
@abc.abstractmethod
def predict_category(self):
"""Return the category name that has labels."""
pass
def _get_id(dict, key):
id = dict.get(key, None)
if id is None:
id = len(dict)
dict[key] = id
return id
def save_strlist(filename, strlist):
with open(filename, 'w') as f:
for s in strlist:
f.write(s + '\n')
def load_strlist(filename):
with open(filename, 'r') as f:
ret = []
for line in f:
ret.append(line.strip())
return ret
class AIFB(RDFGraphDataset):
"""AIFB dataset.
Examples
--------
>>> dataset = dgl.data.rdf.AIFB()
>>> print(dataset.graph)
"""
employs = rdf.term.URIRef("http://swrc.ontoware.org/ontology#employs")
affiliation = rdf.term.URIRef("http://swrc.ontoware.org/ontology#affiliation")
entity_prefix = 'http://www.aifb.uni-karlsruhe.de/'
relation_prefix = 'http://swrc.ontoware.org/'
def __init__(self,
force_reload=False,
print_every=10000,
insert_reverse=True):
url = _get_dgl_url('dataset/rdf/aifb-hetero.zip')
name = 'aifb-hetero'
super(AIFB, self).__init__(url, name,
force_reload=force_reload,
print_every=print_every,
insert_reverse=insert_reverse)
def parse_entity(self, term):
if isinstance(term, rdf.Literal):
return Entity(id=str(term), cls="_Literal")
if isinstance(term, rdf.BNode):
return None
entstr = str(term)
if entstr.startswith(self.entity_prefix):
sp = entstr.split('/')
return Entity(id=sp[5], cls=sp[3])
else:
return None
def parse_relation(self, term):
if term == self.employs or term == self.affiliation:
return None
relstr = str(term)
if relstr.startswith(self.relation_prefix):
return Relation(cls=relstr.split('/')[3])
else:
relstr = relstr.split('/')[-1]
return Relation(cls=relstr)
def process_tuple(self, raw_tuple, sbj, rel, obj):
if sbj is None or rel is None or obj is None:
return None
return (sbj, rel, obj)
def process_idx_file_line(self, line):
person, _, label = line.strip().split('\t')
return person, label
@property
def predict_category(self):
return 'Personen'
class MUTAG(RDFGraphDataset):
"""MUTAG dataset.
Examples
--------
>>> dataset = dgl.data.rdf.MUTAG()
>>> print(dataset.graph)
"""
d_entity = re.compile("d[0-9]")
bond_entity = re.compile("bond[0-9]")
is_mutagenic = rdf.term.URIRef("http://dl-learner.org/carcinogenesis#isMutagenic")
rdf_type = rdf.term.URIRef("http://www.w3.org/1999/02/22-rdf-syntax-ns#type")
rdf_subclassof = rdf.term.URIRef("http://www.w3.org/2000/01/rdf-schema#subClassOf")
rdf_domain = rdf.term.URIRef("http://www.w3.org/2000/01/rdf-schema#domain")
entity_prefix = 'http://dl-learner.org/carcinogenesis#'
relation_prefix = entity_prefix
def __init__(self,
force_reload=False,
print_every=10000,
insert_reverse=True):
url = _get_dgl_url('dataset/rdf/mutag-hetero.zip')
name = 'mutag-hetero'
super(MUTAG, self).__init__(url, name,
force_reload=force_reload,
print_every=print_every,
insert_reverse=insert_reverse)
def parse_entity(self, term):
if isinstance(term, rdf.Literal):
return Entity(id=str(term), cls="_Literal")
elif isinstance(term, rdf.BNode):
return None
entstr = str(term)
if entstr.startswith(self.entity_prefix):
inst = entstr[len(self.entity_prefix):]
if self.d_entity.match(inst):
cls = 'd'
elif self.bond_entity.match(inst):
cls = 'bond'
else:
cls = None
return Entity(id=inst, cls=cls)
else:
return None
def parse_relation(self, term):
if term == self.is_mutagenic:
return None
relstr = str(term)
if relstr.startswith(self.relation_prefix):
cls = relstr[len(self.relation_prefix):]
return Relation(cls=cls)
else:
relstr = relstr.split('/')[-1]
return Relation(cls=relstr)
def process_tuple(self, raw_tuple, sbj, rel, obj):
if sbj is None or rel is None or obj is None:
return None
if not raw_tuple[1].startswith('http://dl-learner.org/carcinogenesis#'):
obj.cls = 'SCHEMA'
if sbj.cls is None:
sbj.cls = 'SCHEMA'
if obj.cls is None:
obj.cls = rel.cls
assert sbj.cls is not None and obj.cls is not None
return (sbj, rel, obj)
def process_idx_file_line(self, line):
bond, _, label = line.strip().split('\t')
return bond, label
@property
def predict_category(self):
return 'd'
class BGS(RDFGraphDataset):
"""BGS dataset.
BGS namespace convention:
http://data.bgs.ac.uk/(ref|id)/<Major Concept>/<Sub Concept>/INSTANCE
We ignored all literal nodes and the relations connecting them in the
output graph. We also ignored the relation used to mark whether a
term is CURRENT or DEPRECATED.
Examples
--------
>>> dataset = dgl.data.rdf.BGS()
>>> print(dataset.graph)
"""
lith = rdf.term.URIRef("http://data.bgs.ac.uk/ref/Lexicon/hasLithogenesis")
entity_prefix = 'http://data.bgs.ac.uk/'
status_prefix = 'http://data.bgs.ac.uk/ref/CurrentStatus'
relation_prefix = 'http://data.bgs.ac.uk/ref'
def __init__(self,
force_reload=False,
print_every=10000,
insert_reverse=True):
url = _get_dgl_url('dataset/rdf/bgs-hetero.zip')
name = 'bgs-hetero'
super(BGS, self).__init__(url, name,
force_reload=force_reload,
print_every=print_every,
insert_reverse=insert_reverse)
def parse_entity(self, term):
if isinstance(term, rdf.Literal):
return None
elif isinstance(term, rdf.BNode):
return None
entstr = str(term)
if entstr.startswith(self.status_prefix):
return None
if entstr.startswith(self.entity_prefix):
sp = entstr.split('/')
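# Entity URIs follow http://data.bgs.ac.uk/(ref|id)/<Major Concept>/<Sub Concept>/INSTANCE,
# i.e. 7 fields after splitting on '/'; the class is '<Major Concept>/<Sub Concept>'.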
if len(sp) != 7:
return None
# instance
cls = '%s/%s' % (sp[4], sp[5])
inst = sp[6]
return Entity(id=inst, cls=cls)
else:
return None
def parse_relation(self, term):
if term == self.lith:
return None
relstr = str(term)
if relstr.startswith(self.relation_prefix):
sp = relstr.split('/')
if len(sp) < 6:
return None
assert len(sp) == 6, relstr
cls = '%s/%s' % (sp[4], sp[5])
return Relation(cls=cls)
else:
relstr = relstr.replace('.', '_')
return Relation(cls=relstr)
def process_tuple(self, raw_tuple, sbj, rel, obj):
if sbj is None or rel is None or obj is None:
return None
return (sbj, rel, obj)
def process_idx_file_line(self, line):
_, rock, label = line.strip().split('\t')
return rock, label
@property
def predict_category(self):
return 'Lexicon/NamedRockUnit'
class AM(RDFGraphDataset):
"""AM dataset.
Namespace convention:
Instance: http://purl.org/collections/nl/am/<type>-<id>
Relation: http://purl.org/collections/nl/am/<name>
We ignored all literal nodes and the relations connecting them in the
output graph.
Examples
--------
>>> dataset = dgl.data.rdf.AM()
>>> print(dataset.graph)
"""
objectCategory = rdf.term.URIRef("http://purl.org/collections/nl/am/objectCategory")
material = rdf.term.URIRef("http://purl.org/collections/nl/am/material")
entity_prefix = 'http://purl.org/collections/nl/am/'
relation_prefix = entity_prefix
def __init__(self,
force_reload=False,
print_every=10000,
insert_reverse=True):
url = _get_dgl_url('dataset/rdf/am-hetero.zip')
name = 'am-hetero'
super(AM, self).__init__(url, name,
force_reload=force_reload,
print_every=print_every,
insert_reverse=insert_reverse)
def parse_entity(self, term):
if isinstance(term, rdf.Literal):
return None
elif isinstance(term, rdf.BNode):
return Entity(id=str(term), cls='_BNode')
entstr = str(term)
if entstr.startswith(self.entity_prefix):
sp = entstr.split('/')
assert len(sp) == 7, entstr
spp = sp[6].split('-')
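# Instances look like <type>-<id> (see the class docstring); URIs whose last
# segment has no dash fall into the catch-all 'TYPE' class.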
if len(spp) == 2:
# instance
cls, inst = spp
else:
cls = 'TYPE'
inst = spp
return Entity(id=inst, cls=cls)
else:
return None
def parse_relation(self, term):
if term == self.objectCategory or term == self.material:
return None
relstr = str(term)
if relstr.startswith(self.relation_prefix):
sp = relstr.split('/')
assert len(sp) == 7, relstr
cls = sp[6]
return Relation(cls=cls)
else:
relstr = relstr.replace('.', '_')
return Relation(cls=relstr)
def process_tuple(self, raw_tuple, sbj, rel, obj):
if sbj is None or rel is None or obj is None:
return None
return (sbj, rel, obj)
def process_idx_file_line(self, line):
proxy, _, label = line.strip().split('\t')
return proxy, label
@property
def predict_category(self):
return 'proxy'
if __name__ == '__main__':
AIFB()