Unverified Commit 5d5436ba authored by Xiangkun Hu's avatar Xiangkun Hu Committed by GitHub
Browse files

[Dataset] RedditDataset change data.train_mask to numpy array (#1961)



* PPIDataset

* Revert "PPIDataset"

This reverts commit 264bd0c960cfa698a7bb946dad132bf52c2d0c8a.

* Update reddit.py
Co-authored-by: default avatarxiang song(charlie.song) <classicxsong@gmail.com>
parent 35c9473b
......@@ -94,11 +94,11 @@ class RedditDataset(DGLBuiltinDataset):
Graph of the dataset
num_labels : int
Number of classes for each node
train_mask: Tensor
train_mask: numpy.ndarray
Mask of training nodes
val_mask: Tensor
val_mask: numpy.ndarray
Mask of validation nodes
test_mask: Tensor
test_mask: numpy.ndarray
Mask of test nodes
features : Tensor
Node features
......@@ -202,17 +202,17 @@ class RedditDataset(DGLBuiltinDataset):
@property
def train_mask(self):
deprecate_property('dataset.train_mask', 'graph.ndata[\'train_mask\']')
return self._graph.ndata['train_mask']
return F.asnumpy(self._graph.ndata['train_mask'])
@property
def val_mask(self):
deprecate_property('dataset.val_mask', 'graph.ndata[\'val_mask\']')
return self._graph.ndata['val_mask']
return F.asnumpy(self._graph.ndata['val_mask'])
@property
def test_mask(self):
deprecate_property('dataset.test_mask', 'graph.ndata[\'test_mask\']')
return self._graph.ndata['test_mask']
return F.asnumpy(self._graph.ndata['test_mask'])
@property
def features(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment