Unverified commit 2f4146a4, authored by Rhett Ying, committed by GitHub

[Dataset] add train_idx/val_idx/test_idx as dataset attributes (#3769)

* add train_idx/val_idx/test_idx as dataset attributes

* refine docstring
parent e7ad4c9c
@@ -234,11 +234,16 @@ Fake news dataset
Dataset adapters
-------------------
Node prediction adapter
```````````````````````
.. autoclass:: AsNodePredDataset
    :members: __getitem__, __len__
Link prediction adapter
```````````````````````
-.. autoclass:: AsEdgePredDataset
+.. autoclass:: AsLinkPredDataset
    :members: __getitem__, __len__
@@ -22,11 +22,14 @@ class AsNodePredDataset(DGLDataset):
- Contains only one graph, accessible from ``dataset[0]``.
- The graph stores:
- Node labels in ``g.ndata['label']``.
- Train/val/test masks in ``g.ndata['train_mask']``, ``g.ndata['val_mask']``,
and ``g.ndata['test_mask']`` respectively.
- In addition, the dataset contains the following attributes:
- ``num_classes``, the number of classes to predict.
- ``train_idx``, ``val_idx``, ``test_idx``: the node IDs of the train/val/test splits, respectively.
If the input dataset contains heterogeneous graphs, users need to specify the
``target_ntype`` argument to indicate which node type to make predictions for.
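A minimal usage sketch of the adapter described in the docstring above (editor's illustration, not part of this diff; ``CoraGraphDataset`` is just one convenient input, and the mask arithmetic assumes the PyTorch backend):

```python
import dgl.data as data

# Wrap a node-classification dataset; Cora is assumed here purely for
# illustration -- any DGLDataset with node labels should work.
ds = data.AsNodePredDataset(data.CoraGraphDataset(), split_ratio=[0.8, 0.1, 0.1])
g = ds[0]                           # the single graph
print(g.ndata['label'].shape)       # node labels
print(g.ndata['train_mask'].sum())  # boolean split mask
print(ds.num_classes)               # number of classes to predict
print(ds.train_idx[:5])             # new attribute added by this commit
```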
@@ -54,6 +57,12 @@ class AsNodePredDataset(DGLDataset):
----------
num_classes : int
    Number of classes to predict.
train_idx : Tensor
    A 1-D integer tensor of training node IDs.
val_idx : Tensor
    A 1-D integer tensor of validation node IDs.
test_idx : Tensor
    A 1-D integer tensor of test node IDs.
Examples
--------
@@ -114,6 +123,8 @@ class AsNodePredDataset(DGLDataset):
            print('Generating train/val/test masks...')
        utils.add_nodepred_split(self, self.split_ratio, self.target_ntype)
        self._set_split_index()
        self.num_classes = getattr(self.dataset, 'num_classes', None)
        if self.num_classes is None:
            self.num_classes = len(F.unique(self.g.nodes[self.target_ntype].data['label']))
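The ``num_classes`` logic above prefers the wrapped dataset's own attribute and only falls back to counting distinct label values. A standalone sketch of the same pattern (editor's illustration; plain PyTorch stands in for DGL's ``F`` backend, and all names are toy stand-ins):

```python
import torch

class ToyDataset:
    """A stand-in dataset with no num_classes attribute."""

dataset = ToyDataset()
labels = torch.tensor([0, 2, 1, 2, 0])  # toy node labels

# Prefer the dataset's own num_classes; otherwise infer it from the labels.
num_classes = getattr(dataset, 'num_classes', None)
if num_classes is None:
    num_classes = len(torch.unique(labels))
print(num_classes)  # 3
```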
@@ -133,6 +144,7 @@ class AsNodePredDataset(DGLDataset):
        self.num_classes = info['num_classes']
        gs, _ = utils.load_graphs(os.path.join(self.save_path, 'graph_{}.bin'.format(self.hash)))
        self.g = gs[0]
        self._set_split_index()

    def save(self):
        utils.save_graphs(os.path.join(self.save_path, 'graph_{}.bin'.format(self.hash)), [self.g])
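Note that the index attributes are never serialized: ``save()`` persists only the graph (whose masks ride along in ``ndata``), and ``load()`` re-derives ``train_idx``/``val_idx``/``test_idx`` by calling ``_set_split_index()``. The cache file name also embeds ``self.hash``, so changing the constructor arguments (e.g. ``split_ratio``) invalidates the cache, which the "invalid cache, re-read" test cases below rely on. A toy sketch of hash-keyed caching (editor's illustration with hypothetical names, not DGL's API):

```python
import hashlib
import os

def cache_path(save_dir, split_ratio):
    # Key the cache file on the arguments, so a new split_ratio
    # yields a new file instead of reusing a stale split.
    key = hashlib.sha1(str(split_ratio).encode()).hexdigest()[:8]
    return os.path.join(save_dir, 'graph_{}.bin'.format(key))

print(cache_path('/tmp', [0.8, 0.1, 0.1]))  # differs from [0.1, 0.1, 0.8]
```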
@@ -148,6 +160,13 @@ class AsNodePredDataset(DGLDataset):
    def __len__(self):
        return 1

    def _set_split_index(self):
        """Add train_idx/val_idx/test_idx as dataset attributes according to the corresponding masks."""
        ndata = self.g.nodes[self.target_ntype].data
        self.train_idx = F.nonzero_1d(ndata['train_mask'])
        self.val_idx = F.nonzero_1d(ndata['val_mask'])
        self.test_idx = F.nonzero_1d(ndata['test_mask'])
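``F.nonzero_1d`` returns the positions at which a 1-D boolean mask is set, which is exactly the mask-to-index conversion the new attributes need. A PyTorch equivalent with a toy mask (editor's sketch):

```python
import torch

train_mask = torch.tensor([True, False, True, True, False])
# Equivalent of F.nonzero_1d under the PyTorch backend.
train_idx = torch.nonzero(train_mask, as_tuple=True)[0]
print(train_idx)  # tensor([0, 2, 3])
```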
def negative_sample(g, num_samples):
    """Randomly sample negative edges from a graph, excluding self-loops,
@@ -1196,6 +1196,12 @@ def test_as_nodepred1():
    assert new_ds[0].num_nodes() == ds[0].num_nodes()
    assert new_ds[0].num_edges() == ds[0].num_edges()
    assert 'train_mask' in new_ds[0].ndata
    assert F.array_equal(new_ds.train_idx, F.nonzero_1d(
        new_ds[0].ndata['train_mask']))
    assert F.array_equal(new_ds.val_idx, F.nonzero_1d(
        new_ds[0].ndata['val_mask']))
    assert F.array_equal(new_ds.test_idx, F.nonzero_1d(
        new_ds[0].ndata['test_mask']))

    ds = data.AIFBDataset()
    print('train_mask' in ds[0].nodes['Personen'].data)
@@ -1204,6 +1210,12 @@ def test_as_nodepred1():
    assert new_ds[0].ntypes == ds[0].ntypes
    assert new_ds[0].canonical_etypes == ds[0].canonical_etypes
    assert 'train_mask' in new_ds[0].nodes['Personen'].data
    assert F.array_equal(new_ds.train_idx, F.nonzero_1d(
        new_ds[0].nodes['Personen'].data['train_mask']))
    assert F.array_equal(new_ds.val_idx, F.nonzero_1d(
        new_ds[0].nodes['Personen'].data['val_mask']))
    assert F.array_equal(new_ds.test_idx, F.nonzero_1d(
        new_ds[0].nodes['Personen'].data['test_mask']))
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_as_nodepred2():
@@ -1212,27 +1224,38 @@ def test_as_nodepred2():
    # create
    ds = data.AsNodePredDataset(data.AmazonCoBuyComputerDataset(), [0.8, 0.1, 0.1])
    assert F.sum(F.astype(ds[0].ndata['train_mask'], F.int32), 0) == int(ds[0].num_nodes() * 0.8)
    assert len(ds.train_idx) == int(ds[0].num_nodes() * 0.8)

    # read from cache
    ds = data.AsNodePredDataset(data.AmazonCoBuyComputerDataset(), [0.8, 0.1, 0.1])
    assert F.sum(F.astype(ds[0].ndata['train_mask'], F.int32), 0) == int(ds[0].num_nodes() * 0.8)
    assert len(ds.train_idx) == int(ds[0].num_nodes() * 0.8)

    # invalid cache, re-read
    ds = data.AsNodePredDataset(data.AmazonCoBuyComputerDataset(), [0.1, 0.1, 0.8])
    assert F.sum(F.astype(ds[0].ndata['train_mask'], F.int32), 0) == int(ds[0].num_nodes() * 0.1)
    assert len(ds.train_idx) == int(ds[0].num_nodes() * 0.1)

    # create
    ds = data.AsNodePredDataset(data.AIFBDataset(), [0.8, 0.1, 0.1], 'Personen', verbose=True)
    assert F.sum(F.astype(ds[0].nodes['Personen'].data['train_mask'], F.int32), 0) == int(ds[0].num_nodes('Personen') * 0.8)
    assert len(ds.train_idx) == int(ds[0].num_nodes('Personen') * 0.8)

    # read from cache
    ds = data.AsNodePredDataset(data.AIFBDataset(), [0.8, 0.1, 0.1], 'Personen', verbose=True)
    assert F.sum(F.astype(ds[0].nodes['Personen'].data['train_mask'], F.int32), 0) == int(ds[0].num_nodes('Personen') * 0.8)
    assert len(ds.train_idx) == int(ds[0].num_nodes('Personen') * 0.8)

    # invalid cache, re-read
    ds = data.AsNodePredDataset(data.AIFBDataset(), [0.1, 0.1, 0.8], 'Personen', verbose=True)
    assert F.sum(F.astype(ds[0].nodes['Personen'].data['train_mask'], F.int32), 0) == int(ds[0].num_nodes('Personen') * 0.1)
    assert len(ds.train_idx) == int(ds[0].num_nodes('Personen') * 0.1)
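For concreteness, the size assertions above are plain floor arithmetic: with ``split_ratio=[0.8, 0.1, 0.1]`` the adapter assigns ``int(num_nodes * 0.8)`` training nodes, and ``len(ds.train_idx)`` must match. A worked example (editor's illustration; the node count is assumed purely for the arithmetic):

```python
num_nodes = 13752              # assumed node count of a co-purchase graph
split_ratio = [0.8, 0.1, 0.1]
num_train = int(num_nodes * split_ratio[0])
print(num_train)               # 11001, since int() truncates 11001.6
```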
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason="ogb only supports pytorch")
def test_as_nodepred_ogb():
    from ogb.nodeproppred import DglNodePropPredDataset

    ds = data.AsNodePredDataset(DglNodePropPredDataset("ogbn-arxiv"), split_ratio=None, verbose=True)
    split = DglNodePropPredDataset("ogbn-arxiv").get_idx_split()
    train_idx, val_idx, test_idx = split['train'], split['valid'], split['test']
    assert F.array_equal(ds.train_idx, F.tensor(train_idx))
    assert F.array_equal(ds.val_idx, F.tensor(val_idx))
    assert F.array_equal(ds.test_idx, F.tensor(test_idx))

    # force generate new split
    ds = data.AsNodePredDataset(DglNodePropPredDataset("ogbn-arxiv"), split_ratio=[0.7, 0.2, 0.1], verbose=True)
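The test above pins down the ``split_ratio=None`` behavior: when the wrapped dataset ships a predefined split, as OGB datasets do via ``get_idx_split()``, the adapter adopts it verbatim instead of generating a new one. A condensed sketch of the same check (editor's illustration; assumes the ``ogb`` package is installed and the PyTorch backend is active):

```python
import torch
from ogb.nodeproppred import DglNodePropPredDataset
import dgl.data as data

ogb_ds = DglNodePropPredDataset('ogbn-arxiv')
ds = data.AsNodePredDataset(ogb_ds, split_ratio=None)

# The adapter's attributes should match OGB's predefined split verbatim.
split = ogb_ds.get_idx_split()
assert torch.equal(ds.train_idx, split['train'])
assert torch.equal(ds.val_idx, split['valid'])
```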