Unverified commit 5152a879, authored by Minjie Wang, committed by GitHub

[Data] Utility function and class for converting a dataset for node prediction (#3695)

* add ut

* add doc link

* install dep

* fix ci

* fix ut; more comments

* remove deprecated attributes in rdf datasets; fix label feature name

* address comments

* fix ut for other frameworks
parent 56b5d0e5
@@ -227,6 +227,13 @@ Fake news dataset
.. autoclass:: FakeNewsDataset
    :members: __getitem__, __len__

Dataset adapters
```````````````````````````````````

.. autoclass:: AsNodePredDataset
    :members: __getitem__, __len__

Utilities
-----------------
@@ -241,6 +248,7 @@ Utilities
    utils.load_labels
    utils.save_info
    utils.load_info
    utils.add_nodepred_split

.. autoclass:: dgl.data.utils.Subset
    :members: __getitem__, __len__
@@ -30,6 +30,7 @@ from .rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
from .fraud import FraudDataset, FraudYelpDataset, FraudAmazonDataset
from .fakenews import FakeNewsDataset
from .csv_dataset import DGLCSVDataset
from .adapter import AsNodePredDataset

def register_data_args(parser):
    parser.add_argument(
...
"""Dataset adapters for re-purposing a dataset for a different kind of training task."""
import os
import json
from .dgl_dataset import DGLDataset
from . import utils
__all__ = ['AsNodePredDataset']
class AsNodePredDataset(DGLDataset):
"""Repurpose a dataset for a standard semi-supervised transductive
node prediction task.
The class converts a given dataset into a new dataset object that:
- Contains only one graph, accessible from ``dataset[0]``.
- The graph stores:
- Node labels in ``g.ndata['label']``.
- Train/val/test masks in ``g.ndata['train_mask']``, ``g.ndata['val_mask']``,
and ``g.ndata['test_mask']`` respectively.
- In addition, the dataset contains the following attributes:
- ``num_classes``, the number of classes to predict.
If the input dataset contains heterogeneous graphs, users need to specify the
``target_ntype`` argument to indicate which node type to make predictions for.
In this case:
- Node labels are stored in ``g.nodes[target_ntype].data['label']``.
- Training masks are stored in ``g.nodes[target_ntype].data['train_mask']``.
So do validation and test masks.
The class will keep only the first graph in the provided dataset and
generate train/val/test masks according to the given spplit ratio. The generated
masks will be cached to disk for fast re-loading. If the provided split ratio
differs from the cached one, it will re-process the dataset properly.
Parameters
----------
dataset : DGLDataset
The dataset to be converted.
split_ratio : (float, float, float), optional
Split ratios for training, validation and test sets. Must sum to one.
target_ntype : str, optional
The node type to add split mask for.
Attributes
----------
num_classes : int
Number of classes to predict.
Examples
--------
>>> ds = dgl.data.AmazonCoBuyComputerDataset()
>>> print(ds)
Dataset("amazon_co_buy_computer", num_graphs=1, save_path=...)
>>> new_ds = dgl.data.AsNodePredDataset(ds, [0.8, 0.1, 0.1])
>>> print(new_ds)
Dataset("amazon_co_buy_computer-as-nodepred", num_graphs=1, save_path=...)
>>> print('train_mask' in new_ds[0].ndata)
True
"""
def __init__(self,
dataset,
split_ratio=[0.8, 0.1, 0.1],
target_ntype=None,
**kwargs):
self.g = dataset[0].clone()
self.split_ratio = split_ratio
self.target_ntype = target_ntype
self.num_classes = dataset.num_classes
super().__init__(dataset.name + '-as-nodepred', **kwargs)
def process(self):
if 'label' not in self.g.nodes[self.target_ntype].data:
raise ValueError("Missing node labels. Make sure labels are stored "
"under name 'label'.")
if self.verbose:
print('Generating train/val/test masks...')
utils.add_nodepred_split(self, self.split_ratio, self.target_ntype)
def has_cache(self):
return os.path.isfile(os.path.join(self.save_path, 'graph.bin'))
def load(self):
with open(os.path.join(self.save_path, 'info.json'), 'r') as f:
info = json.load(f)
if (info['split_ratio'] != self.split_ratio
or info['target_ntype'] != self.target_ntype):
raise ValueError('Provided split ratio is different from the cached file. '
'Re-process the dataset.')
self.split_ratio = info['split_ratio']
self.target_ntype = info['target_ntype']
self.num_classes = info['num_classes']
gs, _ = utils.load_graphs(os.path.join(self.save_path, 'graph.bin'))
self.g = gs[0]
def save(self):
utils.save_graphs(os.path.join(self.save_path, 'graph.bin'), [self.g])
with open(os.path.join(self.save_path, 'info.json'), 'w') as f:
json.dump({
'split_ratio' : self.split_ratio,
'target_ntype' : self.target_ntype,
'num_classes' : self.num_classes}, f)
def __getitem__(self, idx):
return self.g
def __len__(self):
return 1
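Aside (not part of the commit): a minimal usage sketch of the new adapter. The linear layer is a toy stand-in for a real GNN model, and the feature key ``'feat'`` is assumed to be the one used by ``AmazonCoBuyComputerDataset``.

```python
import torch
import torch.nn.functional as Fn
import dgl.data

# Wrap a built-in dataset; the masks and the 'label' field are added by AsNodePredDataset.
dataset = dgl.data.AsNodePredDataset(
    dgl.data.AmazonCoBuyComputerDataset(), [0.8, 0.1, 0.1])
g = dataset[0]
feat = g.ndata['feat']              # node features (key assumed for this dataset)
label = g.ndata['label']
train_mask = g.ndata['train_mask']  # boolean mask generated by the adapter

model = torch.nn.Linear(feat.shape[1], dataset.num_classes)  # toy stand-in for a GNN
loss = Fn.cross_entropy(model(feat)[train_mask], label[train_mask])
loss.backward()
```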
@@ -57,16 +57,17 @@ class DGLDataset(object):
    name : str
        The dataset name
    raw_dir : str
-        Raw file directory contains the input data folder
        Directory to store all the downloaded raw datasets.
    raw_path : str
-        Directory contains the input data files.
-        Default : ``os.path.join(self.raw_dir, self.name)``
        Path to the downloaded raw dataset folder. An alias for
        ``os.path.join(self.raw_dir, self.name)``.
    save_dir : str
-        Directory to save the processed dataset
        Directory to save all the processed datasets.
    save_path : str
-        File path to save the processed dataset
        Path to the processed dataset folder. An alias for
        ``os.path.join(self.save_dir, self.name)``.
    verbose : bool
-        Whether to print information
        Whether to print more runtime information.
    hash : str
        Hash value for the dataset and the setting.
    """
@@ -123,10 +124,11 @@ class DGLDataset(object):
        """
        pass

    @abc.abstractmethod
    def process(self):
        r"""Overwrite to realize your own logic of processing the input data.
        """
-        raise NotImplementedError
        pass

    def has_cache(self):
        r"""Overwrite to realize your own logic of
@@ -138,9 +140,11 @@ class DGLDataset(object):
    @retry_method_with_fix(download)
    def _download(self):
-        r"""Download dataset by calling ``self.download()`` if the dataset does not exists under ``self.raw_path``.
-        By default ``self.raw_path = os.path.join(self.raw_dir, self.name)``
-        One can overwrite ``raw_path()`` function to change the path.
        """Download dataset by calling ``self.download()``
        if the dataset does not exist under ``self.raw_path``.

        By default ``self.raw_path = os.path.join(self.raw_dir, self.name)``.
        One can override the ``raw_path`` property to change the path.
        """
        if os.path.exists(self.raw_path):  # pragma: no cover
            return
@@ -149,14 +153,18 @@ class DGLDataset(object):
        self.download()

    def _load(self):
-        r"""Entry point from __init__ to load the dataset.
-        if the cache exists:
-            Load the dataset from saved dgl graph and information files.
-            If loadin process fails, re-download and process the dataset.
-        else:
-            1. Download the dataset if needed.
-            2. Process the dataset and build the dgl graph.
-            3. Save the processed dataset into files.
        """Entry point from __init__ to load the dataset.

        If the cache exists:

        - Load the dataset from saved dgl graph and information files.
        - If the loading process fails, re-download and process the dataset.

        else:

        - Download the dataset if needed.
        - Process the dataset and build the dgl graph.
        - Save the processed dataset into files.
        """
        load_flag = not self._force_reload and self.has_cache()
@@ -255,6 +263,10 @@ class DGLDataset(object):
        r"""The number of examples in the dataset."""
        pass

    def __repr__(self):
        return f'Dataset("{self.name}", num_graphs={len(self)},' + \
               f' save_path={self.save_path})'

class DGLBuiltinDataset(DGLDataset):
    r"""The Basic DGL Builtin Dataset.
...
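Aside (not part of the commit): a minimal hypothetical subclass illustrating the load/process/save protocol that ``_load`` documents above. The name ``MyDataset``, the placeholder random graph, and the file layout are invented for illustration.

```python
import os
import dgl
from dgl.data import DGLDataset
from dgl.data.utils import save_graphs, load_graphs

class MyDataset(DGLDataset):
    """Hypothetical dataset showing the caching protocol."""
    def __init__(self, **kwargs):
        super().__init__(name='my_dataset', **kwargs)

    def process(self):
        # Normally this parses raw files under self.raw_path; here, a random graph.
        self.g = dgl.rand_graph(100, 500)

    def has_cache(self):
        return os.path.isfile(os.path.join(self.save_path, 'graph.bin'))

    def load(self):
        self.g = load_graphs(os.path.join(self.save_path, 'graph.bin'))[0][0]

    def save(self):
        os.makedirs(self.save_path, exist_ok=True)
        save_graphs(os.path.join(self.save_path, 'graph.bin'), [self.g])

    def __getitem__(self, idx):
        return self.g

    def __len__(self):
        return 1

print(MyDataset())  # uses the new __repr__: Dataset("my_dataset", num_graphs=1, ...)
```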
@@ -16,9 +16,9 @@ import dgl
import dgl.backend as F
from .dgl_dataset import DGLBuiltinDataset
from .utils import save_graphs, load_graphs, save_info, load_info, _get_dgl_url
-from .utils import generate_mask_tensor, idx2mask, deprecate_property, deprecate_class
from .utils import generate_mask_tensor, idx2mask

-__all__ = ['AIFB', 'MUTAG', 'BGS', 'AM', 'AIFBDataset', 'MUTAGDataset', 'BGSDataset', 'AMDataset']
__all__ = ['AIFBDataset', 'MUTAGDataset', 'BGSDataset', 'AMDataset']

# Dictionary for renaming reserved node/edge type names to the ones
# that are allowed by nn.Module.
@@ -72,18 +72,10 @@ class RDFGraphDataset(DGLBuiltinDataset):
    Attributes
    ----------
-    graph : dgl.DGLGraph
-        Graph structure
    num_classes : int
        Number of classes to predict
    predict_category : str
        The entity category (node type) that has labels for prediction
-    train_idx : Tensor
-        Entity IDs for training. All IDs are local IDs w.r.t. ``predict_category``.
-    test_idx : Tensor
-        Entity IDs for testing. All IDs are local IDs w.r.t. ``predict_category``.
-    labels : Tensor
-        All the labels of the entities in ``predict_category``

    Parameters
    ----------
@@ -243,14 +235,11 @@ class RDFGraphDataset(DGLBuiltinDataset):
        test_mask = generate_mask_tensor(test_mask)
        self._hg.nodes[self.predict_category].data['train_mask'] = train_mask
        self._hg.nodes[self.predict_category].data['test_mask'] = test_mask
        # TODO(minjie): Deprecate 'labels', use 'label' for consistency.
        self._hg.nodes[self.predict_category].data['labels'] = labels
        self._hg.nodes[self.predict_category].data['label'] = labels
        self._num_classes = num_classes
-        # save for compatibility
-        self._train_idx = F.tensor(train_idx)
-        self._test_idx = F.tensor(test_idx)
-        self._labels = labels

    def build_graph(self, mg, src, dst, ntid, etid, ntypes, etypes):
        """Build the graphs
@@ -411,14 +400,10 @@ class RDFGraphDataset(DGLBuiltinDataset):
        self._num_classes = info['num_classes']
        self._predict_category = info['predict_category']
        self._hg = graphs[0]
-        train_mask = self._hg.nodes[self.predict_category].data['train_mask']
-        test_mask = self._hg.nodes[self.predict_category].data['test_mask']
-        self._labels = self._hg.nodes[self.predict_category].data['labels']
-        train_idx = F.nonzero_1d(train_mask)
-        test_idx = F.nonzero_1d(test_mask)
-        self._train_idx = train_idx
-        self._test_idx = test_idx
        # For backward compatibility
        if 'label' not in self._hg.nodes[self.predict_category].data:
            self._hg.nodes[self.predict_category].data['label'] = \
                self._hg.nodes[self.predict_category].data['labels']

    def __getitem__(self, idx):
        r"""Gets the graph object
@@ -434,11 +419,6 @@
    def save_name(self):
        return self.name + '_dgl_graph'

-    @property
-    def graph(self):
-        deprecate_property('dataset.graph', 'hg = dataset[0]')
-        return self._hg

    @property
    def predict_category(self):
        return self._predict_category
@@ -447,21 +427,6 @@
    def num_classes(self):
        return self._num_classes
-    @property
-    def train_idx(self):
-        deprecate_property('dataset.train_idx', 'train_mask = g.ndata[\'train_mask\']')
-        return self._train_idx
-
-    @property
-    def test_idx(self):
-        deprecate_property('dataset.test_idx', 'test_mask = g.ndata[\'test_mask\']')
-        return self._test_idx
-
-    @property
-    def labels(self):
-        deprecate_property('dataset.labels', 'labels = g.ndata[\'labels\']')
-        return self._labels

    @abc.abstractmethod
    def parse_entity(self, term):
        """Parse one entity from an RDF term.
@@ -541,27 +506,6 @@ def _get_id(dict, key):
class AIFBDataset(RDFGraphDataset):
    r"""AIFB dataset for node classification task

-    .. deprecated:: 0.5.0
-        - ``graph`` is deprecated, it is replaced by:
-          >>> dataset = AIFBDataset()
-          >>> graph = dataset[0]
-        - ``train_idx`` is deprecated, it can be replaced by:
-          >>> dataset = AIFBDataset()
-          >>> graph = dataset[0]
-          >>> train_mask = graph.nodes[dataset.category].data['train_mask']
-          >>> train_idx = th.nonzero(train_mask, as_tuple=False).squeeze()
-        - ``test_idx`` is deprecated, it can be replaced by:
-          >>> dataset = AIFBDataset()
-          >>> graph = dataset[0]
-          >>> test_mask = graph.nodes[dataset.category].data['test_mask']
-          >>> test_idx = th.nonzero(test_mask, as_tuple=False).squeeze()

    AIFB DataSet is a Semantic Web (RDF) dataset used as a benchmark in
    data mining. It records the organizational structure of AIFB at the
    University of Karlsruhe.
@@ -597,14 +541,6 @@ class AIFBDataset(RDFGraphDataset):
        Number of classes to predict
    predict_category : str
        The entity category (node type) that has labels for prediction
-    labels : Tensor
-        All the labels of the entities in ``predict_category``
-    graph : :class:`dgl.DGLGraph`
-        Graph structure
-    train_idx : Tensor
-        Entity IDs for training. All IDs are local IDs w.r.t. ``predict_category``.
-    test_idx : Tensor
-        Entity IDs for testing. All IDs are local IDs w.r.t. ``predict_category``.

    Examples
    --------
@@ -613,9 +549,9 @@ class AIFBDataset(RDFGraphDataset):
    >>> category = dataset.predict_category
    >>> num_classes = dataset.num_classes
    >>>
-    >>> train_mask = g.nodes[category].data.pop('train_mask')
-    >>> test_mask = g.nodes[category].data.pop('test_mask')
-    >>> labels = g.nodes[category].data.pop('labels')
    >>> train_mask = g.nodes[category].data['train_mask']
    >>> test_mask = g.nodes[category].data['test_mask']
    >>> label = g.nodes[category].data['label']
    """

    entity_prefix = 'http://www.aifb.uni-karlsruhe.de/'
@@ -656,7 +592,7 @@ class AIFBDataset(RDFGraphDataset):
            - ``ndata['train_mask']``: mask for training node set
            - ``ndata['test_mask']``: mask for testing node set
-            - ``ndata['labels']``: mask for labels
            - ``ndata['label']``: node labels
        """
        return super(AIFBDataset, self).__getitem__(idx)
@@ -701,47 +637,9 @@ class AIFBDataset(RDFGraphDataset):
        person, _, label = line.strip().split('\t')
        return person, label

-class AIFB(AIFBDataset):
-    """AIFB dataset. Same as AIFBDataset.
-    """
-    def __init__(self,
-                 print_every=10000,
-                 insert_reverse=True,
-                 raw_dir=None,
-                 force_reload=False,
-                 verbose=True):
-        deprecate_class('AIFB', 'AIFBDataset')
-        super(AIFB, self).__init__(print_every,
-                                   insert_reverse,
-                                   raw_dir,
-                                   force_reload,
-                                   verbose)
class MUTAGDataset(RDFGraphDataset):
    r"""MUTAG dataset for node classification task

-    .. deprecated:: 0.5.0
-        - ``graph`` is deprecated, it is replaced by:
-          >>> dataset = MUTAGDataset()
-          >>> graph = dataset[0]
-        - ``train_idx`` is deprecated, it can be replaced by:
-          >>> dataset = MUTAGDataset()
-          >>> graph = dataset[0]
-          >>> train_mask = graph.nodes[dataset.category].data['train_mask']
-          >>> train_idx = th.nonzero(train_mask).squeeze()
-        - ``test_idx`` is deprecated, it can be replaced by:
-          >>> dataset = MUTAGDataset()
-          >>> graph = dataset[0]
-          >>> test_mask = graph.nodes[dataset.category].data['test_mask']
-          >>> test_idx = th.nonzero(test_mask).squeeze()

    Mutag dataset statistics:

    - Nodes: 27163
@@ -773,14 +671,8 @@ class MUTAGDataset(RDFGraphDataset):
        Number of classes to predict
    predict_category : str
        The entity category (node type) that has labels for prediction
-    labels : Tensor
-        All the labels of the entities in ``predict_category``
    graph : :class:`dgl.DGLGraph`
        Graph structure
-    train_idx : Tensor
-        Entity IDs for training. All IDs are local IDs w.r.t. ``predict_category``.
-    test_idx : Tensor
-        Entity IDs for testing. All IDs are local IDs w.r.t. ``predict_category``.

    Examples
    --------
@@ -789,9 +681,9 @@ class MUTAGDataset(RDFGraphDataset):
    >>> category = dataset.predict_category
    >>> num_classes = dataset.num_classes
    >>>
-    >>> train_mask = g.nodes[category].data.pop('train_mask')
-    >>> test_mask = g.nodes[category].data.pop('test_mask')
-    >>> labels = g.nodes[category].data.pop('labels')
    >>> train_mask = g.nodes[category].data['train_mask']
    >>> test_mask = g.nodes[category].data['test_mask']
    >>> label = g.nodes[category].data['label']
    """

    d_entity = re.compile("d[0-9]")
@@ -838,7 +730,7 @@ class MUTAGDataset(RDFGraphDataset):
            - ``ndata['train_mask']``: mask for training node set
            - ``ndata['test_mask']``: mask for testing node set
-            - ``ndata['labels']``: mask for labels
            - ``ndata['label']``: node labels
        """
        return super(MUTAGDataset, self).__getitem__(idx)
@@ -900,46 +792,9 @@ class MUTAGDataset(RDFGraphDataset):
        bond, _, label = line.strip().split('\t')
        return bond, label

-class MUTAG(MUTAGDataset):
-    """MUTAG dataset. Same as MUTAGDataset.
-    """
-    def __init__(self,
-                 print_every=10000,
-                 insert_reverse=True,
-                 raw_dir=None,
-                 force_reload=False,
-                 verbose=True):
-        deprecate_class('MUTAG', 'MUTAGDataset')
-        super(MUTAG, self).__init__(print_every,
-                                    insert_reverse,
-                                    raw_dir,
-                                    force_reload,
-                                    verbose)
class BGSDataset(RDFGraphDataset):
    r"""BGS dataset for node classification task

-    .. deprecated:: 0.5.0
-        - ``graph`` is deprecated, it is replaced by:
-          >>> dataset = BGSDataset()
-          >>> graph = dataset[0]
-        - ``train_idx`` is deprecated, it can be replaced by:
-          >>> dataset = BGSDataset()
-          >>> graph = dataset[0]
-          >>> train_mask = graph.nodes[dataset.category].data['train_mask']
-          >>> train_idx = th.nonzero(train_mask).squeeze()
-        - ``test_idx`` is deprecated, it can be replaced by:
-          >>> dataset = BGSDataset()
-          >>> graph = dataset[0]
-          >>> test_mask = graph.nodes[dataset.category].data['test_mask']
-          >>> test_idx = th.nonzero(test_mask).squeeze()

    BGS namespace convention:
    ``http://data.bgs.ac.uk/(ref|id)/<Major Concept>/<Sub Concept>/INSTANCE``.
    We ignored all literal nodes and the relations connecting them in the
@@ -976,15 +831,7 @@ class BGSDataset(RDFGraphDataset):
    num_classes : int
        Number of classes to predict
    predict_category : str
        The entity category (node type) that has labels for prediction
-    labels : Tensor
-        All the labels of the entities in ``predict_category``
-    graph : :class:`dgl.DGLGraph`
-        Graph structure
-    train_idx : Tensor
-        Entity IDs for training. All IDs are local IDs w.r.t. ``predict_category``.
-    test_idx : Tensor
-        Entity IDs for testing. All IDs are local IDs w.r.t. ``predict_category``.

    Examples
    --------
@@ -993,9 +840,9 @@ class BGSDataset(RDFGraphDataset):
    >>> category = dataset.predict_category
    >>> num_classes = dataset.num_classes
    >>>
-    >>> train_mask = g.nodes[category].data.pop('train_mask')
-    >>> test_mask = g.nodes[category].data.pop('test_mask')
-    >>> labels = g.nodes[category].data.pop('labels')
    >>> train_mask = g.nodes[category].data['train_mask']
    >>> test_mask = g.nodes[category].data['test_mask']
    >>> label = g.nodes[category].data['label']
    """

    entity_prefix = 'http://data.bgs.ac.uk/'
@@ -1036,7 +883,7 @@ class BGSDataset(RDFGraphDataset):
            - ``ndata['train_mask']``: mask for training node set
            - ``ndata['test_mask']``: mask for testing node set
-            - ``ndata['labels']``: mask for labels
            - ``ndata['label']``: node labels
        """
        return super(BGSDataset, self).__getitem__(idx)
@@ -1093,47 +940,9 @@ class BGSDataset(RDFGraphDataset):
        _, rock, label = line.strip().split('\t')
        return rock, label

-class BGS(BGSDataset):
-    """BGS dataset. Same as BGSDataset.
-    """
-    def __init__(self,
-                 print_every=10000,
-                 insert_reverse=True,
-                 raw_dir=None,
-                 force_reload=False,
-                 verbose=True):
-        deprecate_class('BGS', 'BGSDataset')
-        super(BGS, self).__init__(print_every,
-                                  insert_reverse,
-                                  raw_dir,
-                                  force_reload,
-                                  verbose)
class AMDataset(RDFGraphDataset):
    """AM dataset for node classification task

-    .. deprecated:: 0.5.0
-        - ``graph`` is deprecated, it is replaced by:
-          >>> dataset = AMDataset()
-          >>> graph = dataset[0]
-        - ``train_idx`` is deprecated, it can be replaced by:
-          >>> dataset = AMDataset()
-          >>> graph = dataset[0]
-          >>> train_mask = graph.nodes[dataset.category].data['train_mask']
-          >>> train_idx = th.nonzero(train_mask).squeeze()
-        - ``test_idx`` is deprecated, it can be replaced by:
-          >>> dataset = AMDataset()
-          >>> graph = dataset[0]
-          >>> test_mask = graph.nodes[dataset.category].data['test_mask']
-          >>> test_idx = th.nonzero(test_mask).squeeze()

    Namespace convention:

    - Instance: ``http://purl.org/collections/nl/am/<type>-<id>``
@@ -1173,14 +982,6 @@ class AMDataset(RDFGraphDataset):
        Number of classes to predict
    predict_category : str
        The entity category (node type) that has labels for prediction
-    labels : Tensor
-        All the labels of the entities in ``predict_category``
-    graph : :class:`dgl.DGLGraph`
-        Graph structure
-    train_idx : Tensor
-        Entity IDs for training. All IDs are local IDs w.r.t. ``predict_category``.
-    test_idx : Tensor
-        Entity IDs for testing. All IDs are local IDs w.r.t. ``predict_category``.

    Examples
    --------
@@ -1189,9 +990,9 @@ class AMDataset(RDFGraphDataset):
    >>> category = dataset.predict_category
    >>> num_classes = dataset.num_classes
    >>>
-    >>> train_mask = g.nodes[category].data.pop('train_mask')
-    >>> test_mask = g.nodes[category].data.pop('test_mask')
-    >>> labels = g.nodes[category].data.pop('labels')
    >>> train_mask = g.nodes[category].data['train_mask']
    >>> test_mask = g.nodes[category].data['test_mask']
    >>> label = g.nodes[category].data['label']
    """

    entity_prefix = 'http://purl.org/collections/nl/am/'
@@ -1232,7 +1033,7 @@ class AMDataset(RDFGraphDataset):
            - ``ndata['train_mask']``: mask for training node set
            - ``ndata['test_mask']``: mask for testing node set
-            - ``ndata['labels']``: mask for labels
            - ``ndata['label']``: node labels
        """
        return super(AMDataset, self).__getitem__(idx)
@@ -1287,22 +1088,3 @@ class AMDataset(RDFGraphDataset):
    def process_idx_file_line(self, line):
        proxy, _, label = line.strip().split('\t')
        return proxy, label

-class AM(AMDataset):
-    """AM dataset. Same as AMDataset.
-    """
-    def __init__(self,
-                 print_every=10000,
-                 insert_reverse=True,
-                 raw_dir=None,
-                 force_reload=False,
-                 verbose=True):
-        deprecate_class('AM', 'AMDataset')
-        super(AM, self).__init__(print_every,
-                                 insert_reverse,
-                                 raw_dir,
-                                 force_reload,
-                                 verbose)
-
-if __name__ == '__main__':
-    dataset = AIFB()
@@ -19,8 +19,10 @@ from .tensor_serialize import save_tensors, load_tensors
from .. import backend as F

__all__ = ['loadtxt', 'download', 'check_sha1', 'extract_archive',
-           'get_download_dir', 'Subset', 'split_dataset',
-           'save_graphs', "load_graphs", "load_labels", "save_tensors", "load_tensors"]
           'get_download_dir', 'Subset', 'split_dataset', 'save_graphs',
           'load_graphs', 'load_labels', 'save_tensors', 'load_tensors',
           'add_nodepred_split',
           ]

def loadtxt(path, delimiter, dtype=None):
    try:
@@ -351,3 +353,47 @@ class Subset(object):
            Number of datapoints in the subset
        """
        return len(self.indices)

def add_nodepred_split(dataset, ratio, ntype=None):
    """Split the given dataset into training, validation and test sets for a
    transductive node prediction task.

    It adds three node mask arrays ``'train_mask'``, ``'val_mask'`` and ``'test_mask'``
    to each graph in the dataset. Each sample in the dataset thus must be a
    :class:`DGLGraph`.

    To make the result deterministic, fix the random seed of NumPy before calling::

        numpy.random.seed(42)

    Parameters
    ----------
    dataset : DGLDataset
        The dataset to modify.
    ratio : (float, float, float)
        Split ratios for training, validation and test sets. Must sum to one.
    ntype : str, optional
        The node type to add the masks for.

    Examples
    --------
    >>> dataset = dgl.data.AmazonCoBuyComputerDataset()
    >>> print('train_mask' in dataset[0].ndata)
    False
    >>> dgl.data.utils.add_nodepred_split(dataset, [0.8, 0.1, 0.1])
    >>> print('train_mask' in dataset[0].ndata)
    True
    """
    if len(ratio) != 3:
        raise ValueError(f'Split ratio must be a float triplet but got {ratio}.')
    for i in range(len(dataset)):
        g = dataset[i]
        n = g.num_nodes(ntype)
        idx = np.arange(0, n)
        np.random.shuffle(idx)
        n_train, n_val, n_test = int(n * ratio[0]), int(n * ratio[1]), int(n * ratio[2])
        train_mask = generate_mask_tensor(idx2mask(idx[:n_train], n))
        val_mask = generate_mask_tensor(idx2mask(idx[n_train:n_train + n_val], n))
        test_mask = generate_mask_tensor(idx2mask(idx[n_train + n_val:], n))
        g.nodes[ntype].data['train_mask'] = train_mask
        g.nodes[ntype].data['val_mask'] = val_mask
        g.nodes[ntype].data['test_mask'] = test_mask
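Aside (not part of the commit): rough sketches of the two helpers used above, for readers skimming the diff. The bodies approximate ``idx2mask`` and ``generate_mask_tensor`` under a PyTorch backend; they are not the exact DGL implementations.

```python
import numpy as np
import torch

def idx2mask_sketch(idx, length):
    # Boolean array with True at the given node indices (approximation).
    mask = np.zeros(length, dtype=bool)
    mask[idx] = True
    return mask

def generate_mask_tensor_sketch(mask):
    # Convert the NumPy mask into a backend tensor (bool tensor under PyTorch).
    return torch.from_numpy(mask)

# e.g., an 80/20 split over 10 shuffled node IDs
idx = np.random.permutation(10)
print(idx2mask_sketch(idx[:8], 10))
```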
@@ -1013,6 +1013,59 @@ def test_csvdataset():
    _test_DGLCSVDataset_multiple()
    _test_DGLCSVDataset_customized_data_parser()

@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_add_nodepred_split():
    dataset = data.AmazonCoBuyComputerDataset()
    print('train_mask' in dataset[0].ndata)
    data.utils.add_nodepred_split(dataset, [0.8, 0.1, 0.1])
    assert 'train_mask' in dataset[0].ndata

    dataset = data.AIFBDataset()
    print('train_mask' in dataset[0].nodes['Publikationen'].data)
    data.utils.add_nodepred_split(dataset, [0.8, 0.1, 0.1], ntype='Publikationen')
    assert 'train_mask' in dataset[0].nodes['Publikationen'].data

@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_as_nodepred1():
    ds = data.AmazonCoBuyComputerDataset()
    print('train_mask' in ds[0].ndata)
    new_ds = data.AsNodePredDataset(ds, [0.8, 0.1, 0.1], verbose=True)
    assert len(new_ds) == 1
    assert new_ds[0].num_nodes() == ds[0].num_nodes()
    assert new_ds[0].num_edges() == ds[0].num_edges()
    assert 'train_mask' in new_ds[0].ndata

    ds = data.AIFBDataset()
    print('train_mask' in ds[0].nodes['Personen'].data)
    new_ds = data.AsNodePredDataset(ds, [0.8, 0.1, 0.1], 'Personen', verbose=True)
    assert len(new_ds) == 1
    assert new_ds[0].ntypes == ds[0].ntypes
    assert new_ds[0].canonical_etypes == ds[0].canonical_etypes
    assert 'train_mask' in new_ds[0].nodes['Personen'].data

@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_as_nodepred2():
    # test proper reprocessing

    # create
    ds = data.AsNodePredDataset(data.AmazonCoBuyComputerDataset(), [0.8, 0.1, 0.1])
    assert F.sum(F.astype(ds[0].ndata['train_mask'], F.int32), 0) == int(ds[0].num_nodes() * 0.8)
    # read from cache
    ds = data.AsNodePredDataset(data.AmazonCoBuyComputerDataset(), [0.8, 0.1, 0.1])
    assert F.sum(F.astype(ds[0].ndata['train_mask'], F.int32), 0) == int(ds[0].num_nodes() * 0.8)
    # invalid cache, re-read
    ds = data.AsNodePredDataset(data.AmazonCoBuyComputerDataset(), [0.1, 0.1, 0.8])
    assert F.sum(F.astype(ds[0].ndata['train_mask'], F.int32), 0) == int(ds[0].num_nodes() * 0.1)

    # create
    ds = data.AsNodePredDataset(data.AIFBDataset(), [0.8, 0.1, 0.1], 'Personen', verbose=True)
    assert F.sum(F.astype(ds[0].nodes['Personen'].data['train_mask'], F.int32), 0) == int(ds[0].num_nodes('Personen') * 0.8)
    # read from cache
    ds = data.AsNodePredDataset(data.AIFBDataset(), [0.8, 0.1, 0.1], 'Personen', verbose=True)
    assert F.sum(F.astype(ds[0].nodes['Personen'].data['train_mask'], F.int32), 0) == int(ds[0].num_nodes('Personen') * 0.8)
    # invalid cache, re-read
    ds = data.AsNodePredDataset(data.AIFBDataset(), [0.1, 0.1, 0.8], 'Personen', verbose=True)
    assert F.sum(F.astype(ds[0].nodes['Personen'].data['train_mask'], F.int32), 0) == int(ds[0].num_nodes('Personen') * 0.1)

if __name__ == '__main__':
    test_minigc()
@@ -1023,3 +1076,6 @@ if __name__ == '__main__':
    test_fakenews()
    test_extract_archive()
    test_csvdataset()
    test_add_nodepred_split()
    test_as_nodepred1()
    test_as_nodepred2()
@@ -14,7 +14,7 @@ SET DGLBACKEND=!BACKEND!
SET DGL_LIBRARY_PATH=!CD!\build
SET DGL_DOWNLOAD_DIR=!CD!

-python -m pip install pytest pyyaml pandas pydantic || EXIT /B 1
python -m pip install pytest pyyaml pandas pydantic rdflib || EXIT /B 1
python -m pytest -v --junitxml=pytest_backend.xml tests\!DGLBACKEND! || EXIT /B 1
python -m pytest -v --junitxml=pytest_compute.xml tests\compute || EXIT /B 1
ENDLOCAL
...
@@ -32,7 +32,7 @@ fi
conda activate ${DGLBACKEND}-ci

-python3 -m pip install pytest pyyaml pandas pydantic || EXIT /B 1
python3 -m pip install pytest pyyaml pandas pydantic rdflib || EXIT /B 1
python3 -m pytest -v --junitxml=pytest_compute.xml tests/compute || fail "compute"
python3 -m pytest -v --junitxml=pytest_backend.xml tests/$DGLBACKEND || fail "backend-specific"
...