"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "5883d8d4d1ea14cfb29433f1039ecf20f8afd777"
Unverified Commit 5d5436ba authored by Xiangkun Hu's avatar Xiangkun Hu Committed by GitHub
Browse files

[Dataset] RedditDataset change data.train_mask to numpy array (#1961)



* PPIDataset

* Revert "PPIDataset"

This reverts commit 264bd0c960cfa698a7bb946dad132bf52c2d0c8a.

* Update reddit.py
Co-authored-by: default avatarxiang song(charlie.song) <classicxsong@gmail.com>
parent 35c9473b
...@@ -94,11 +94,11 @@ class RedditDataset(DGLBuiltinDataset): ...@@ -94,11 +94,11 @@ class RedditDataset(DGLBuiltinDataset):
Graph of the dataset Graph of the dataset
num_labels : int num_labels : int
Number of classes for each node Number of classes for each node
train_mask: Tensor train_mask: numpy.ndarray
Mask of training nodes Mask of training nodes
val_mask: Tensor val_mask: numpy.ndarray
Mask of validation nodes Mask of validation nodes
test_mask: Tensor test_mask: numpy.ndarray
Mask of test nodes Mask of test nodes
features : Tensor features : Tensor
Node features Node features
...@@ -202,17 +202,17 @@ class RedditDataset(DGLBuiltinDataset): ...@@ -202,17 +202,17 @@ class RedditDataset(DGLBuiltinDataset):
@property @property
def train_mask(self): def train_mask(self):
deprecate_property('dataset.train_mask', 'graph.ndata[\'train_mask\']') deprecate_property('dataset.train_mask', 'graph.ndata[\'train_mask\']')
return self._graph.ndata['train_mask'] return F.asnumpy(self._graph.ndata['train_mask'])
@property @property
def val_mask(self): def val_mask(self):
deprecate_property('dataset.val_mask', 'graph.ndata[\'val_mask\']') deprecate_property('dataset.val_mask', 'graph.ndata[\'val_mask\']')
return self._graph.ndata['val_mask'] return F.asnumpy(self._graph.ndata['val_mask'])
@property @property
def test_mask(self): def test_mask(self):
deprecate_property('dataset.test_mask', 'graph.ndata[\'test_mask\']') deprecate_property('dataset.test_mask', 'graph.ndata[\'test_mask\']')
return self._graph.ndata['test_mask'] return F.asnumpy(self._graph.ndata['test_mask'])
@property @property
def features(self): def features(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment