Unverified Commit 5152a879 authored by Minjie Wang, committed by GitHub

[Data] Utility function and class for converting a dataset for node prediction (#3695)

* add ut

* add doc link

* install dep

* fix ci

* fix ut; more comments

* remove deprecated attributes in rdf datasets; fix label feature name

* address comments

* fix ut for other frameworks
parent 56b5d0e5
......@@ -227,6 +227,13 @@ Fake news dataset
.. autoclass:: FakeNewsDataset
:members: __getitem__, __len__
Dataset adapters
```````````````````````````````````
.. autoclass:: AsNodePredDataset
:members: __getitem__, __len__
Utilities
-----------------
......@@ -241,6 +248,7 @@ Utilities
utils.load_labels
utils.save_info
utils.load_info
utils.add_nodepred_split
.. autoclass:: dgl.data.utils.Subset
:members: __getitem__, __len__
......@@ -30,6 +30,7 @@ from .rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
from .fraud import FraudDataset, FraudYelpDataset, FraudAmazonDataset
from .fakenews import FakeNewsDataset
from .csv_dataset import DGLCSVDataset
from .adapter import AsNodePredDataset
def register_data_args(parser):
parser.add_argument(
......
"""Dataset adapters for re-purposing a dataset for a different kind of training task."""
import os
import json
from .dgl_dataset import DGLDataset
from . import utils
__all__ = ['AsNodePredDataset']
class AsNodePredDataset(DGLDataset):
"""Repurpose a dataset for a standard semi-supervised transductive
node prediction task.
The class converts a given dataset into a new dataset object that:
- Contains only one graph, accessible from ``dataset[0]``.
- The graph stores:
- Node labels in ``g.ndata['label']``.
- Train/val/test masks in ``g.ndata['train_mask']``, ``g.ndata['val_mask']``,
and ``g.ndata['test_mask']`` respectively.
- In addition, the dataset contains the following attributes:
- ``num_classes``, the number of classes to predict.
If the input dataset contains heterogeneous graphs, users need to specify the
``target_ntype`` argument to indicate which node type to make predictions for.
In this case:
- Node labels are stored in ``g.nodes[target_ntype].data['label']``.
- Train/val/test masks are stored in ``g.nodes[target_ntype].data['train_mask']``,
``g.nodes[target_ntype].data['val_mask']`` and ``g.nodes[target_ntype].data['test_mask']``
respectively.
The class keeps only the first graph of the provided dataset and
generates train/val/test masks according to the given split ratio. The generated
masks are cached to disk for fast re-loading. If the provided split ratio
differs from the cached one, the dataset is re-processed.
Parameters
----------
dataset : DGLDataset
The dataset to be converted.
split_ratio : (float, float, float), optional
Split ratios for training, validation and test sets. Must sum to one. Default: [0.8, 0.1, 0.1].
target_ntype : str, optional
The node type to add the split masks for. Must be specified for heterogeneous graphs.
Attributes
----------
num_classes : int
Number of classes to predict.
Examples
--------
>>> ds = dgl.data.AmazonCoBuyComputerDataset()
>>> print(ds)
Dataset("amazon_co_buy_computer", num_graphs=1, save_path=...)
>>> new_ds = dgl.data.AsNodePredDataset(ds, [0.8, 0.1, 0.1])
>>> print(new_ds)
Dataset("amazon_co_buy_computer-as-nodepred", num_graphs=1, save_path=...)
>>> print('train_mask' in new_ds[0].ndata)
True
"""
def __init__(self,
dataset,
split_ratio=[0.8, 0.1, 0.1],
target_ntype=None,
**kwargs):
self.g = dataset[0].clone()
self.split_ratio = split_ratio
self.target_ntype = target_ntype
self.num_classes = dataset.num_classes
super().__init__(dataset.name + '-as-nodepred', **kwargs)
def process(self):
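# target_ntype=None (the homogeneous case) is resolved by DGL to the graph's
# only node type, so the same code path handles both cases.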
if 'label' not in self.g.nodes[self.target_ntype].data:
raise ValueError("Missing node labels. Make sure labels are stored "
"under name 'label'.")
if self.verbose:
print('Generating train/val/test masks...')
utils.add_nodepred_split(self, self.split_ratio, self.target_ntype)
def has_cache(self):
return os.path.isfile(os.path.join(self.save_path, 'graph.bin'))
def load(self):
with open(os.path.join(self.save_path, 'info.json'), 'r') as f:
info = json.load(f)
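# A mismatch with the cached configuration invalidates the cache; raising
# here makes DGLDataset._load fall back to re-processing the dataset.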
if (info['split_ratio'] != self.split_ratio
or info['target_ntype'] != self.target_ntype):
raise ValueError('Provided split ratio or target node type differs from '
'the cached ones. Re-processing the dataset.')
self.split_ratio = info['split_ratio']
self.target_ntype = info['target_ntype']
self.num_classes = info['num_classes']
gs, _ = utils.load_graphs(os.path.join(self.save_path, 'graph.bin'))
self.g = gs[0]
def save(self):
utils.save_graphs(os.path.join(self.save_path, 'graph.bin'), [self.g])
with open(os.path.join(self.save_path, 'info.json'), 'w') as f:
json.dump({
'split_ratio' : self.split_ratio,
'target_ntype' : self.target_ntype,
'num_classes' : self.num_classes}, f)
def __getitem__(self, idx):
return self.g
def __len__(self):
return 1
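For heterogeneous inputs, a minimal usage sketch along the lines of the tests further down (assuming the AIFB dataset, whose labeled node type is 'Personen'):

import dgl.data as data

# Repurpose a heterogeneous RDF dataset; masks are added to the target node type.
ds = data.AsNodePredDataset(data.AIFBDataset(), [0.8, 0.1, 0.1], target_ntype='Personen')
g = ds[0]
assert 'train_mask' in g.nodes['Personen'].data
print(ds.num_classes)  # number of classes to predict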
......@@ -57,16 +57,17 @@ class DGLDataset(object):
name : str
The dataset name
raw_dir : str
Raw file directory contains the input data folder
Directory to store all the downloaded raw datasets.
raw_path : str
Directory contains the input data files.
Default : ``os.path.join(self.raw_dir, self.name)``
Path to the downloaded raw dataset folder. An alias for
``os.path.join(self.raw_dir, self.name)``.
save_dir : str
Directory to save the processed dataset
Directory to save all the processed datasets.
save_path : str
File path to save the processed dataset
Path to the processed dataset folder. An alias for
``os.path.join(self.save_dir, self.name)``.
verbose : bool
Whether to print information
Whether to print more runtime information.
hash : str
Hash value for the dataset and the setting.
"""
......@@ -123,10 +124,11 @@ class DGLDataset(object):
"""
pass
@abc.abstractmethod
def process(self):
r"""Overwrite to realize your own logic of processing the input data.
"""
raise NotImplementedError
pass
def has_cache(self):
r"""Overwrite to realize your own logic of
......@@ -138,7 +140,9 @@ class DGLDataset(object):
@retry_method_with_fix(download)
def _download(self):
r"""Download dataset by calling ``self.download()`` if the dataset does not exists under ``self.raw_path``.
"""Download dataset by calling ``self.download()``
if the dataset does not exist under ``self.raw_path``.
By default ``self.raw_path = os.path.join(self.raw_dir, self.name)``.
One can overwrite ``raw_path()`` to change the path.
"""
......@@ -149,14 +153,18 @@ class DGLDataset(object):
self.download()
def _load(self):
r"""Entry point from __init__ to load the dataset.
if the cache exists:
Load the dataset from saved dgl graph and information files.
If loadin process fails, re-download and process the dataset.
"""Entry point from __init__ to load the dataset.
If cache exists:
- Load the dataset from saved dgl graph and information files.
- If the loading process fails, re-download and process the dataset.
else:
1. Download the dataset if needed.
2. Process the dataset and build the dgl graph.
3. Save the processed dataset into files.
- Download the dataset if needed.
- Process the dataset and build the dgl graph.
- Save the processed dataset into files.
"""
load_flag = not self._force_reload and self.has_cache()
......@@ -255,6 +263,10 @@ class DGLDataset(object):
r"""The number of examples in the dataset."""
pass
def __repr__(self):
return f'Dataset("{self.name}", num_graphs={len(self)},' + \
f' save_path={self.save_path})'
class DGLBuiltinDataset(DGLDataset):
r"""The Basic DGL Builtin Dataset.
......
......@@ -16,9 +16,9 @@ import dgl
import dgl.backend as F
from .dgl_dataset import DGLBuiltinDataset
from .utils import save_graphs, load_graphs, save_info, load_info, _get_dgl_url
from .utils import generate_mask_tensor, idx2mask, deprecate_property, deprecate_class
from .utils import generate_mask_tensor, idx2mask
__all__ = ['AIFB', 'MUTAG', 'BGS', 'AM', 'AIFBDataset', 'MUTAGDataset', 'BGSDataset', 'AMDataset']
__all__ = ['AIFBDataset', 'MUTAGDataset', 'BGSDataset', 'AMDataset']
# Dictionary for renaming reserved node/edge type names to the ones
# that are allowed by nn.Module.
......@@ -72,18 +72,10 @@ class RDFGraphDataset(DGLBuiltinDataset):
Attributes
----------
graph : dgl.DGLGraph
Graph structure
num_classes : int
Number of classes to predict
predict_category : str
The entity category (node type) that has labels for prediction
train_idx : Tensor
Entity IDs for training. All IDs are local IDs w.r.t. ``predict_category``.
test_idx : Tensor
Entity IDs for testing. All IDs are local IDs w.r.t. ``predict_category``.
labels : Tensor
All the labels of the entities in ``predict_category``
Parameters
----------
......@@ -243,14 +235,11 @@ class RDFGraphDataset(DGLBuiltinDataset):
test_mask = generate_mask_tensor(test_mask)
self._hg.nodes[self.predict_category].data['train_mask'] = train_mask
self._hg.nodes[self.predict_category].data['test_mask'] = test_mask
# TODO(minjie): Deprecate 'labels', use 'label' for consistency.
self._hg.nodes[self.predict_category].data['labels'] = labels
self._hg.nodes[self.predict_category].data['label'] = labels
self._num_classes = num_classes
# save for compatability
self._train_idx = F.tensor(train_idx)
self._test_idx = F.tensor(test_idx)
self._labels = labels
def build_graph(self, mg, src, dst, ntid, etid, ntypes, etypes):
"""Build the graphs
......@@ -411,14 +400,10 @@ class RDFGraphDataset(DGLBuiltinDataset):
self._num_classes = info['num_classes']
self._predict_category = info['predict_category']
self._hg = graphs[0]
train_mask = self._hg.nodes[self.predict_category].data['train_mask']
test_mask = self._hg.nodes[self.predict_category].data['test_mask']
self._labels = self._hg.nodes[self.predict_category].data['labels']
train_idx = F.nonzero_1d(train_mask)
test_idx = F.nonzero_1d(test_mask)
self._train_idx = train_idx
self._test_idx = test_idx
# For backward compatibility
if 'label' not in self._hg.nodes[self.predict_category].data:
self._hg.nodes[self.predict_category].data['label'] = \
self._hg.nodes[self.predict_category].data['labels']
def __getitem__(self, idx):
r"""Gets the graph object
......@@ -434,11 +419,6 @@ class RDFGraphDataset(DGLBuiltinDataset):
def save_name(self):
return self.name + '_dgl_graph'
@property
def graph(self):
deprecate_property('dataset.graph', 'hg = dataset[0]')
return self._hg
@property
def predict_category(self):
return self._predict_category
......@@ -447,21 +427,6 @@ class RDFGraphDataset(DGLBuiltinDataset):
def num_classes(self):
return self._num_classes
@property
def train_idx(self):
deprecate_property('dataset.train_idx', 'train_mask = g.ndata[\'train_mask\']')
return self._train_idx
@property
def test_idx(self):
deprecate_property('dataset.test_idx', 'test_mask = g.ndata[\'test_mask\']')
return self._test_idx
@property
def labels(self):
deprecate_property('dataset.labels', 'labels = g.ndata[\'labels\']')
return self._labels
@abc.abstractmethod
def parse_entity(self, term):
"""Parse one entity from an RDF term.
......@@ -541,27 +506,6 @@ def _get_id(dict, key):
class AIFBDataset(RDFGraphDataset):
r"""AIFB dataset for node classification task
.. deprecated:: 0.5.0
- ``graph`` is deprecated, it is replaced by:
>>> dataset = AIFBDataset()
>>> graph = dataset[0]
- ``train_idx`` is deprecated, it can be replaced by:
>>> dataset = AIFBDataset()
>>> graph = dataset[0]
>>> train_mask = graph.nodes[dataset.category].data['train_mask']
>>> train_idx = th.nonzero(train_mask, as_tuple=False).squeeze()
- ``test_idx`` is deprecated, it can be replaced by:
>>> dataset = AIFBDataset()
>>> graph = dataset[0]
>>> test_mask = graph.nodes[dataset.category].data['test_mask']
>>> test_idx = th.nonzero(test_mask, as_tuple=False).squeeze()
The AIFB dataset is a Semantic Web (RDF) dataset used as a benchmark in
data mining. It records the organizational structure of AIFB at the
University of Karlsruhe.
......@@ -597,14 +541,6 @@ class AIFBDataset(RDFGraphDataset):
Number of classes to predict
predict_category : str
The entity category (node type) that has labels for prediction
labels : Tensor
All the labels of the entities in ``predict_category``
graph : :class:`dgl.DGLGraph`
Graph structure
train_idx : Tensor
Entity IDs for training. All IDs are local IDs w.r.t. ``predict_category``.
test_idx : Tensor
Entity IDs for testing. All IDs are local IDs w.r.t. ``predict_category``.
Examples
--------
......@@ -613,9 +549,9 @@ class AIFBDataset(RDFGraphDataset):
>>> category = dataset.predict_category
>>> num_classes = dataset.num_classes
>>>
>>> train_mask = g.nodes[category].data.pop('train_mask')
>>> test_mask = g.nodes[category].data.pop('test_mask')
>>> labels = g.nodes[category].data.pop('labels')
>>> train_mask = g.nodes[category].data['train_mask']
>>> test_mask = g.nodes[category].data['test_mask']
>>> label = g.nodes[category].data['label']
"""
entity_prefix = 'http://www.aifb.uni-karlsruhe.de/'
......@@ -656,7 +592,7 @@ class AIFBDataset(RDFGraphDataset):
- ``ndata['train_mask']``: mask for training node set
- ``ndata['test_mask']``: mask for testing node set
- ``ndata['labels']``: mask for labels
- ``ndata['label']``: node labels
"""
return super(AIFBDataset, self).__getitem__(idx)
......@@ -701,47 +637,9 @@ class AIFBDataset(RDFGraphDataset):
person, _, label = line.strip().split('\t')
return person, label
class AIFB(AIFBDataset):
"""AIFB dataset. Same as AIFBDataset.
"""
def __init__(self,
print_every=10000,
insert_reverse=True,
raw_dir=None,
force_reload=False,
verbose=True):
deprecate_class('AIFB', 'AIFBDataset')
super(AIFB, self).__init__(print_every,
insert_reverse,
raw_dir,
force_reload,
verbose)
class MUTAGDataset(RDFGraphDataset):
r"""MUTAG dataset for node classification task
.. deprecated:: 0.5.0
- ``graph`` is deprecated, it is replaced by:
>>> dataset = MUTAGDataset()
>>> graph = dataset[0]
- ``train_idx`` is deprecated, it can be replaced by:
>>> dataset = MUTAGDataset()
>>> graph = dataset[0]
>>> train_mask = graph.nodes[dataset.category].data['train_mask']
>>> train_idx = th.nonzero(train_mask).squeeze()
- ``test_idx`` is deprecated, it can be replaced by:
>>> dataset = MUTAGDataset()
>>> graph = dataset[0]
>>> test_mask = graph.nodes[dataset.category].data['test_mask']
>>> test_idx = th.nonzero(test_mask).squeeze()
MUTAG dataset statistics:
- Nodes: 27163
......@@ -773,14 +671,8 @@ class MUTAGDataset(RDFGraphDataset):
Number of classes to predict
predict_category : str
The entity category (node type) that has labels for prediction
labels : Tensor
All the labels of the entities in ``predict_category``
graph : :class:`dgl.DGLGraph`
Graph structure
train_idx : Tensor
Entity IDs for training. All IDs are local IDs w.r.t. ``predict_category``.
test_idx : Tensor
Entity IDs for testing. All IDs are local IDs w.r.t. ``predict_category``.
Examples
--------
......@@ -789,9 +681,9 @@ class MUTAGDataset(RDFGraphDataset):
>>> category = dataset.predict_category
>>> num_classes = dataset.num_classes
>>>
>>> train_mask = g.nodes[category].data.pop('train_mask')
>>> test_mask = g.nodes[category].data.pop('test_mask')
>>> labels = g.nodes[category].data.pop('labels')
>>> train_mask = g.nodes[category].data['train_mask']
>>> test_mask = g.nodes[category].data['test_mask']
>>> label = g.nodes[category].data['label']
"""
d_entity = re.compile("d[0-9]")
......@@ -838,7 +730,7 @@ class MUTAGDataset(RDFGraphDataset):
- ``ndata['train_mask']``: mask for training node set
- ``ndata['test_mask']``: mask for testing node set
- ``ndata['labels']``: mask for labels
- ``ndata['label']``: node labels
"""
return super(MUTAGDataset, self).__getitem__(idx)
......@@ -900,46 +792,9 @@ class MUTAGDataset(RDFGraphDataset):
bond, _, label = line.strip().split('\t')
return bond, label
class MUTAG(MUTAGDataset):
"""MUTAG dataset. Same as MUTAGDataset.
"""
def __init__(self,
print_every=10000,
insert_reverse=True,
raw_dir=None,
force_reload=False,
verbose=True):
deprecate_class('MUTAG', 'MUTAGDataset')
super(MUTAG, self).__init__(print_every,
insert_reverse,
raw_dir,
force_reload,
verbose)
class BGSDataset(RDFGraphDataset):
r"""BGS dataset for node classification task
.. deprecated:: 0.5.0
- ``graph`` is deprecated, it is replaced by:
>>> dataset = BGSDataset()
>>> graph = dataset[0]
- ``train_idx`` is deprecated, it can be replaced by:
>>> dataset = BGSDataset()
>>> graph = dataset[0]
>>> train_mask = graph.nodes[dataset.category].data['train_mask']
>>> train_idx = th.nonzero(train_mask).squeeze()
- ``test_idx`` is deprecated, it can be replaced by:
>>> dataset = BGSDataset()
>>> graph = dataset[0]
>>> test_mask = graph.nodes[dataset.category].data['test_mask']
>>> test_idx = th.nonzero(test_mask).squeeze()
BGS namespace convention:
``http://data.bgs.ac.uk/(ref|id)/<Major Concept>/<Sub Concept>/INSTANCE``.
We ignored all literal nodes and the relations connecting them in the
......@@ -976,15 +831,7 @@ class BGSDataset(RDFGraphDataset):
num_classes : int
Number of classes to predict
predict_category : str
The entity category (node type) that has labels for prediction
labels : Tensor
All the labels of the entities in ``predict_category``
graph : :class:`dgl.DGLGraph`
Graph structure
train_idx : Tensor
Entity IDs for training. All IDs are local IDs w.r.t. ``predict_category``.
test_idx : Tensor
Entity IDs for testing. All IDs are local IDs w.r.t. ``predict_category``.
Examples
--------
......@@ -993,9 +840,9 @@ class BGSDataset(RDFGraphDataset):
>>> category = dataset.predict_category
>>> num_classes = dataset.num_classes
>>>
>>> train_mask = g.nodes[category].data.pop('train_mask')
>>> test_mask = g.nodes[category].data.pop('test_mask')
>>> labels = g.nodes[category].data.pop('labels')
>>> train_mask = g.nodes[category].data['train_mask']
>>> test_mask = g.nodes[category].data['test_mask']
>>> label = g.nodes[category].data['label']
"""
entity_prefix = 'http://data.bgs.ac.uk/'
......@@ -1036,7 +883,7 @@ class BGSDataset(RDFGraphDataset):
- ``ndata['train_mask']``: mask for training node set
- ``ndata['test_mask']``: mask for testing node set
- ``ndata['labels']``: mask for labels
- ``ndata['label']``: node labels
"""
return super(BGSDataset, self).__getitem__(idx)
......@@ -1093,47 +940,9 @@ class BGSDataset(RDFGraphDataset):
_, rock, label = line.strip().split('\t')
return rock, label
class BGS(BGSDataset):
"""BGS dataset. Same as BGSDataset.
"""
def __init__(self,
print_every=10000,
insert_reverse=True,
raw_dir=None,
force_reload=False,
verbose=True):
deprecate_class('BGS', 'BGSDataset')
super(BGS, self).__init__(print_every,
insert_reverse,
raw_dir,
force_reload,
verbose)
class AMDataset(RDFGraphDataset):
"""AM dataset. for node classification task
.. deprecated:: 0.5.0
- ``graph`` is deprecated, it is replaced by:
>>> dataset = AMDataset()
>>> graph = dataset[0]
- ``train_idx`` is deprecated, it can be replaced by:
>>> dataset = AMDataset()
>>> graph = dataset[0]
>>> train_mask = graph.nodes[dataset.category].data['train_mask']
>>> train_idx = th.nonzero(train_mask).squeeze()
- ``test_idx`` is deprecated, it can be replaced by:
>>> dataset = AMDataset()
>>> graph = dataset[0]
>>> test_mask = graph.nodes[dataset.category].data['test_mask']
>>> test_idx = th.nonzero(test_mask).squeeze()
Namespace convention:
- Instance: ``http://purl.org/collections/nl/am/<type>-<id>``
......@@ -1173,14 +982,6 @@ class AMDataset(RDFGraphDataset):
Number of classes to predict
predict_category : str
The entity category (node type) that has labels for prediction
labels : Tensor
All the labels of the entities in ``predict_category``
graph : :class:`dgl.DGLGraph`
Graph structure
train_idx : Tensor
Entity IDs for training. All IDs are local IDs w.r.t. ``predict_category``.
test_idx : Tensor
Entity IDs for testing. All IDs are local IDs w.r.t. ``predict_category``.
Examples
--------
......@@ -1189,9 +990,9 @@ class AMDataset(RDFGraphDataset):
>>> category = dataset.predict_category
>>> num_classes = dataset.num_classes
>>>
>>> train_mask = g.nodes[category].data.pop('train_mask')
>>> test_mask = g.nodes[category].data.pop('test_mask')
>>> labels = g.nodes[category].data.pop('labels')
>>> train_mask = g.nodes[category].data['train_mask']
>>> test_mask = g.nodes[category].data['test_mask']
>>> label = g.nodes[category].data['label']
"""
entity_prefix = 'http://purl.org/collections/nl/am/'
......@@ -1232,7 +1033,7 @@ class AMDataset(RDFGraphDataset):
- ``ndata['train_mask']``: mask for training node set
- ``ndata['test_mask']``: mask for testing node set
- ``ndata['labels']``: mask for labels
- ``ndata['label']``: node labels
"""
return super(AMDataset, self).__getitem__(idx)
......@@ -1287,22 +1088,3 @@ class AMDataset(RDFGraphDataset):
def process_idx_file_line(self, line):
proxy, _, label = line.strip().split('\t')
return proxy, label
class AM(AMDataset):
"""AM dataset. Same as AMDataset.
"""
def __init__(self,
print_every=10000,
insert_reverse=True,
raw_dir=None,
force_reload=False,
verbose=True):
deprecate_class('AM', 'AMDataset')
super(AM, self).__init__(print_every,
insert_reverse,
raw_dir,
force_reload,
verbose)
if __name__ == '__main__':
dataset = AIFBDataset()
......@@ -19,8 +19,10 @@ from .tensor_serialize import save_tensors, load_tensors
from .. import backend as F
__all__ = ['loadtxt','download', 'check_sha1', 'extract_archive',
'get_download_dir', 'Subset', 'split_dataset',
'save_graphs', "load_graphs", "load_labels", "save_tensors", "load_tensors"]
'get_download_dir', 'Subset', 'split_dataset', 'save_graphs',
'load_graphs', 'load_labels', 'save_tensors', 'load_tensors',
'add_nodepred_split',
]
def loadtxt(path, delimiter, dtype=None):
try:
......@@ -351,3 +353,47 @@ class Subset(object):
Number of datapoints in the subset
"""
return len(self.indices)
def add_nodepred_split(dataset, ratio, ntype=None):
"""Split the given dataset into training, validation and test sets for
transductive node predction task.
It adds three node mask arrays ``'train_mask'``, ``'val_mask'`` and ``'test_mask'``,
to each graph in the dataset. Each sample in the dataset thus must be a :class:`DGLGraph`.
To make the result deterministic, fix NumPy's random seed before calling this function::
numpy.random.seed(42)
Parameters
----------
dataset : DGLDataset
The dataset to modify.
ratio : (float, float, float)
Split ratios for training, validation and test sets. Must sum to one.
ntype : str, optional
The node type to add the masks for.
Examples
--------
>>> dataset = dgl.data.AmazonCoBuyComputerDataset()
>>> print('train_mask' in dataset[0].ndata)
False
>>> dgl.data.utils.add_nodepred_split(dataset, [0.8, 0.1, 0.1])
>>> print('train_mask' in dataset[0].ndata)
True
"""
if len(ratio) != 3:
raise ValueError(f'Split ratio must be a float triplet but got {ratio}.')
for i in range(len(dataset)):
g = dataset[i]
n = g.num_nodes(ntype)
idx = np.arange(0, n)
np.random.shuffle(idx)
n_train, n_val, n_test = int(n * ratio[0]), int(n * ratio[1]), int(n * ratio[2])
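# The test mask below takes the open-ended remainder idx[n_train + n_val:], so
# rounding leftovers land in the test set; n_test itself is not used for slicing.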
train_mask = generate_mask_tensor(idx2mask(idx[:n_train], n))
val_mask = generate_mask_tensor(idx2mask(idx[n_train:n_train + n_val], n))
test_mask = generate_mask_tensor(idx2mask(idx[n_train + n_val:], n))
g.nodes[ntype].data['train_mask'] = train_mask
g.nodes[ntype].data['val_mask'] = val_mask
g.nodes[ntype].data['test_mask'] = test_mask
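A minimal sketch of the heterogeneous case, mirroring the test below (assuming AIFB's 'Publikationen' node type):

import dgl.data as data

dataset = data.AIFBDataset()
# Masks are added only to the specified node type.
data.utils.add_nodepred_split(dataset, [0.8, 0.1, 0.1], ntype='Publikationen')
assert 'train_mask' in dataset[0].nodes['Publikationen'].data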
......@@ -1013,6 +1013,59 @@ def test_csvdataset():
_test_DGLCSVDataset_multiple()
_test_DGLCSVDataset_customized_data_parser()
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_add_nodepred_split():
dataset = data.AmazonCoBuyComputerDataset()
print('train_mask' in dataset[0].ndata)
data.utils.add_nodepred_split(dataset, [0.8, 0.1, 0.1])
assert 'train_mask' in dataset[0].ndata
dataset = data.AIFBDataset()
print('train_mask' in dataset[0].nodes['Publikationen'].data)
data.utils.add_nodepred_split(dataset, [0.8, 0.1, 0.1], ntype='Publikationen')
assert 'train_mask' in dataset[0].nodes['Publikationen'].data
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_as_nodepred1():
ds = data.AmazonCoBuyComputerDataset()
print('train_mask' in ds[0].ndata)
new_ds = data.AsNodePredDataset(ds, [0.8, 0.1, 0.1], verbose=True)
assert len(new_ds) == 1
assert new_ds[0].num_nodes() == ds[0].num_nodes()
assert new_ds[0].num_edges() == ds[0].num_edges()
assert 'train_mask' in new_ds[0].ndata
ds = data.AIFBDataset()
print('train_mask' in ds[0].nodes['Personen'].data)
new_ds = data.AsNodePredDataset(ds, [0.8, 0.1, 0.1], 'Personen', verbose=True)
assert len(new_ds) == 1
assert new_ds[0].ntypes == ds[0].ntypes
assert new_ds[0].canonical_etypes == ds[0].canonical_etypes
assert 'train_mask' in new_ds[0].nodes['Personen'].data
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_as_nodepred2():
# test proper reprocessing
# create
ds = data.AsNodePredDataset(data.AmazonCoBuyComputerDataset(), [0.8, 0.1, 0.1])
assert F.sum(F.astype(ds[0].ndata['train_mask'], F.int32), 0) == int(ds[0].num_nodes() * 0.8)
# read from cache
ds = data.AsNodePredDataset(data.AmazonCoBuyComputerDataset(), [0.8, 0.1, 0.1])
assert F.sum(F.astype(ds[0].ndata['train_mask'], F.int32), 0) == int(ds[0].num_nodes() * 0.8)
# invalid cache, re-read
ds = data.AsNodePredDataset(data.AmazonCoBuyComputerDataset(), [0.1, 0.1, 0.8])
assert F.sum(F.astype(ds[0].ndata['train_mask'], F.int32), 0) == int(ds[0].num_nodes() * 0.1)
# create
ds = data.AsNodePredDataset(data.AIFBDataset(), [0.8, 0.1, 0.1], 'Personen', verbose=True)
assert F.sum(F.astype(ds[0].nodes['Personen'].data['train_mask'], F.int32), 0) == int(ds[0].num_nodes('Personen') * 0.8)
# read from cache
ds = data.AsNodePredDataset(data.AIFBDataset(), [0.8, 0.1, 0.1], 'Personen', verbose=True)
assert F.sum(F.astype(ds[0].nodes['Personen'].data['train_mask'], F.int32), 0) == int(ds[0].num_nodes('Personen') * 0.8)
# invalid cache, re-read
ds = data.AsNodePredDataset(data.AIFBDataset(), [0.1, 0.1, 0.8], 'Personen', verbose=True)
assert F.sum(F.astype(ds[0].nodes['Personen'].data['train_mask'], F.int32), 0) == int(ds[0].num_nodes('Personen') * 0.1)
if __name__ == '__main__':
test_minigc()
......@@ -1023,3 +1076,6 @@ if __name__ == '__main__':
test_fakenews()
test_extract_archive()
test_csvdataset()
test_add_nodepred_split()
test_as_nodepred1()
test_as_nodepred2()
......@@ -14,7 +14,7 @@ SET DGLBACKEND=!BACKEND!
SET DGL_LIBRARY_PATH=!CD!\build
SET DGL_DOWNLOAD_DIR=!CD!
python -m pip install pytest pyyaml pandas pydantic || EXIT /B 1
python -m pip install pytest pyyaml pandas pydantic rdflib || EXIT /B 1
python -m pytest -v --junitxml=pytest_backend.xml tests\!DGLBACKEND! || EXIT /B 1
python -m pytest -v --junitxml=pytest_compute.xml tests\compute || EXIT /B 1
ENDLOCAL
......
......@@ -32,7 +32,7 @@ fi
conda activate ${DGLBACKEND}-ci
python3 -m pip install pytest pyyaml pandas pydantic || EXIT /B 1
python3 -m pip install pytest pyyaml pandas pydantic rdflib || fail "pip install"
python3 -m pytest -v --junitxml=pytest_compute.xml tests/compute || fail "compute"
python3 -m pytest -v --junitxml=pytest_backend.xml tests/$DGLBACKEND || fail "backend-specific"
......