Commit 2c141229 authored by Xiangkun Hu, committed by GitHub

[Dataset] RedditDataset (#1914)



* PPIDataset

* Revert "PPIDataset"

This reverts commit 264bd0c960cfa698a7bb946dad132bf52c2d0c8a.

* RedditDataset

* Update reddit.py
Co-authored-by: xiang song (charlie.song) <classicxsong@gmail.com>
parent 37aa99c5
@@ -6,18 +6,9 @@ def load_reddit():
     # load reddit data
     data = RedditDataset(self_loop=True)
-    train_mask = data.train_mask
-    val_mask = data.val_mask
-    features = th.Tensor(data.features)
-    labels = th.LongTensor(data.labels)
-    # Construct graph
-    g = data.graph
-    g.ndata['features'] = features
-    g.ndata['labels'] = labels
-    g.ndata['train_mask'] = th.BoolTensor(data.train_mask)
-    g.ndata['val_mask'] = th.BoolTensor(data.val_mask)
-    g.ndata['test_mask'] = th.BoolTensor(data.test_mask)
+    g = data[0]
+    g.ndata['features'] = g.ndata['feat']
+    g.ndata['labels'] = g.ndata['label']
     return g, data.num_labels

 def load_ogb(name):
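For reference, the loading pattern used throughout these example updates, as a minimal sketch (assuming DGL >= 0.5 with a PyTorch backend; the variable names mirror the diffs):

    from dgl.data import RedditDataset

    data = RedditDataset(self_loop=True)
    g = data[0]                          # the single Reddit graph, with node data attached
    features = g.ndata['feat']           # float32 features, shape (N, 602)
    labels = g.ndata['label']            # int64 node labels
    train_mask = g.ndata['train_mask']   # boolean split masks
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
    n_classes = data.num_classes         # 41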
@@ -324,14 +324,13 @@ if __name__ == '__main__':
     # load reddit data
     data = RedditDataset(self_loop=True)
-    train_mask = data.train_mask
-    val_mask = data.val_mask
-    features = th.Tensor(data.features)
-    in_feats = features.shape[1]
-    labels = th.LongTensor(data.labels)
-    n_classes = data.num_labels
-    # Construct graph
-    g = dgl.graph(data.graph.all_edges())
+    n_classes = data.num_classes
+    g = data[0]
+    features = g.ndata['feat']
+    in_feats = features.shape[1]
+    labels = g.ndata['label']
+    train_mask = g.ndata['train_mask']
+    val_mask = g.ndata['val_mask']
     g.ndata['features'] = features
     prepare_mp(g)
     # Pack data
@@ -387,14 +387,13 @@ if __name__ == '__main__':
     # load reddit data
     data = RedditDataset(self_loop=True)
-    train_mask = data.train_mask
-    val_mask = data.val_mask
-    features = th.Tensor(data.features)
-    in_feats = features.shape[1]
-    labels = th.LongTensor(data.labels)
-    n_classes = data.num_labels
-    # Construct graph
-    g = dgl.graph(data.graph.all_edges())
+    n_classes = data.num_classes
+    g = data[0]
+    features = g.ndata['feat']
+    in_feats = features.shape[1]
+    labels = g.ndata['label']
+    train_mask = g.ndata['train_mask']
+    val_mask = g.ndata['val_mask']
     g.ndata['features'] = features.share_memory_()
     create_history_storage(g, args, n_classes)
@@ -331,15 +331,13 @@ def run(proc_id, n_gpus, args, devices, data):
 def main(args, devices):
     # load reddit data
     data = RedditDataset(self_loop=True)
-    train_mask = data.train_mask
-    val_mask = data.val_mask
-    test_mask = data.test_mask
-    features = th.Tensor(data.features)
-    in_feats = features.shape[1]
-    labels = th.LongTensor(data.labels)
-    n_classes = data.num_labels
-    # Construct graph
-    g = dgl.graph(data.graph.all_edges())
+    n_classes = data.num_classes
+    g = data[0]
+    features = g.ndata['feat']
+    in_feats = features.shape[1]
+    labels = g.ndata['label']
+    train_mask = g.ndata['train_mask']
+    val_mask = g.ndata['val_mask']
     g.ndata['features'] = features
     # Pack data
     data = train_mask, val_mask, test_mask, in_feats, labels, n_classes, g
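The packed `data` tuple is presumably unpacked inside `run()` in the same order; that function body is not part of this diff, so the following is only an illustrative sketch:

    def run(proc_id, n_gpus, args, devices, data):
        # unpack in the same order as the pack line above
        train_mask, val_mask, test_mask, in_feats, labels, n_classes, g = data
        ...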
""" Reddit dataset for community detection """
from __future__ import absolute_import from __future__ import absolute_import
import scipy.sparse as sp import scipy.sparse as sp
import numpy as np import numpy as np
import os, sys import os
from .utils import download, extract_archive, get_download_dir, _get_dgl_url
from ..utils import retry_method_with_fix from .dgl_dataset import DGLBuiltinDataset
from .utils import _get_dgl_url, generate_mask_tensor, load_graphs, save_graphs, deprecate_property
from .. import backend as F from .. import backend as F
from .. import convert from ..convert import graph as dgl_graph
-class RedditDataset(object):
-    def __init__(self, self_loop=False):
-        download_dir = get_download_dir()
+class RedditDataset(DGLBuiltinDataset):
+    r""" Reddit dataset for community detection (node classification)
+
+    .. deprecated:: 0.5.0
+
+        `graph` is deprecated, it is replaced by:
+
+            >>> dataset = RedditDataset()
+            >>> graph = dataset[0]
+
+        `num_labels` is deprecated, it is replaced by:
+
+            >>> dataset = RedditDataset()
+            >>> num_classes = dataset.num_classes
+
+        `train_mask` is deprecated, it is replaced by:
+
+            >>> dataset = RedditDataset()
+            >>> graph = dataset[0]
+            >>> train_mask = graph.ndata['train_mask']
+
+        `val_mask` is deprecated, it is replaced by:
+
+            >>> dataset = RedditDataset()
+            >>> graph = dataset[0]
+            >>> val_mask = graph.ndata['val_mask']
+
+        `test_mask` is deprecated, it is replaced by:
+
+            >>> dataset = RedditDataset()
+            >>> graph = dataset[0]
+            >>> test_mask = graph.ndata['test_mask']
+
+        `features` is deprecated, it is replaced by:
+
+            >>> dataset = RedditDataset()
+            >>> graph = dataset[0]
+            >>> features = graph.ndata['feat']
+
+        `labels` is deprecated, it is replaced by:
+
+            >>> dataset = RedditDataset()
+            >>> graph = dataset[0]
+            >>> labels = graph.ndata['label']
+
+    This is a graph dataset from Reddit posts made in the month of September, 2014.
+    The node label in this case is the community, or "subreddit", that a post belongs to.
+    The authors sampled 50 large communities and built a post-to-post graph, connecting
+    posts if the same user comments on both. In total this dataset contains 232,965
+    posts with an average degree of 492. We use the first 20 days for training and the
+    remaining days for testing (with 30% used for validation).
+
+    Reference: http://snap.stanford.edu/graphsage/
+
+    Statistics
+    ----------
+    Nodes: 232,965
+    Edges: 114,615,892
+    Node feature size: 602
+    Number of training samples: 153,431
+    Number of validation samples: 23,831
+    Number of test samples: 55,703
+
+    Parameters
+    ----------
+    self_loop : bool
+        Whether to load the dataset with self-loop connections. Default: False
+    raw_dir : str
+        Raw file directory to download/contain the input data. Default: ~/.dgl/
+    force_reload : bool
+        Whether to reload the dataset. Default: False
+    verbose : bool
+        Whether to print out progress information. Default: True
+
+    Attributes
+    ----------
+    num_classes : int
+        Number of classes for each node
+    graph : dgl.DGLGraph
+        Graph of the dataset
+    num_labels : int
+        Number of classes for each node
+    train_mask : Tensor
+        Mask of training nodes
+    val_mask : Tensor
+        Mask of validation nodes
+    test_mask : Tensor
+        Mask of test nodes
+    features : Tensor
+        Node features
+    labels : Tensor
+        Node labels
+
+    Examples
+    --------
+    >>> data = RedditDataset()
+    >>> g = data[0]
+    >>> num_classes = data.num_classes
+    >>>
+    >>> # get node features
+    >>> feat = g.ndata['feat']
+    >>>
+    >>> # get the data split
+    >>> train_mask = g.ndata['train_mask']
+    >>> val_mask = g.ndata['val_mask']
+    >>> test_mask = g.ndata['test_mask']
+    >>>
+    >>> # get labels
+    >>> label = g.ndata['label']
+    >>>
+    >>> # Train, Validation and Test
+    """
+    def __init__(self, self_loop=False, raw_dir=None, force_reload=False, verbose=False):
         self_loop_str = ""
         if self_loop:
             self_loop_str = "_self_loop"
-        zip_file_path = os.path.join(download_dir, "reddit{}.zip".format(self_loop_str))
-        extract_dir = os.path.join(download_dir, "reddit{}".format(self_loop_str))
-        self._url = _get_dgl_url("dataset/reddit{}.zip".format(self_loop_str))
-        self._zip_file_path = zip_file_path
-        self._extract_dir = extract_dir
+        _url = _get_dgl_url("dataset/reddit{}.zip".format(self_loop_str))
         self._self_loop_str = self_loop_str
-        self._load()
-
-    def _download(self):
-        download(self._url, path=self._zip_file_path)
-        extract_archive(self._zip_file_path, self._extract_dir)
-
-    @retry_method_with_fix(_download)
-    def _load(self):
+        super(RedditDataset, self).__init__(name='reddit{}'.format(self_loop_str),
+                                            url=_url,
+                                            raw_dir=raw_dir,
+                                            force_reload=force_reload,
+                                            verbose=verbose)
+
+    def process(self):
         # graph
         coo_adj = sp.load_npz(os.path.join(
-            self._extract_dir, "reddit{}_graph.npz".format(self._self_loop_str)))
-        self.graph = convert.graph(coo_adj)
+            self.raw_path, "reddit{}_graph.npz".format(self._self_loop_str)))
+        self._graph = dgl_graph(coo_adj)
         # features and labels
-        reddit_data = np.load(os.path.join(self._extract_dir, "reddit_data.npz"))
-        self.features = reddit_data["feature"]
-        self.labels = reddit_data["label"]
-        self.num_labels = 41
+        reddit_data = np.load(os.path.join(self.raw_path, "reddit_data.npz"))
+        features = reddit_data["feature"]
+        labels = reddit_data["label"]
         # train/val/test indices
         node_types = reddit_data["node_types"]
-        self.train_mask = (node_types == 1)
-        self.val_mask = (node_types == 2)
-        self.test_mask = (node_types == 3)
-        print('Finished data loading.')
-        print(' NumNodes: {}'.format(self.graph.number_of_nodes()))
-        print(' NumEdges: {}'.format(self.graph.number_of_edges()))
-        print(' NumFeats: {}'.format(self.features.shape[1]))
-        print(' NumClasses: {}'.format(self.num_labels))
-        print(' NumTrainingSamples: {}'.format(len(np.nonzero(self.train_mask)[0])))
-        print(' NumValidationSamples: {}'.format(len(np.nonzero(self.val_mask)[0])))
-        print(' NumTestSamples: {}'.format(len(np.nonzero(self.test_mask)[0])))
+        train_mask = (node_types == 1)
+        val_mask = (node_types == 2)
+        test_mask = (node_types == 3)
+        self._graph.ndata['train_mask'] = generate_mask_tensor(train_mask)
+        self._graph.ndata['val_mask'] = generate_mask_tensor(val_mask)
+        self._graph.ndata['test_mask'] = generate_mask_tensor(test_mask)
+        self._graph.ndata['feat'] = F.tensor(features, dtype=F.data_type_dict['float32'])
+        self._graph.ndata['label'] = F.tensor(labels, dtype=F.data_type_dict['int64'])
+        self._print_info()
+
+    def has_cache(self):
+        graph_path = os.path.join(self.save_path, 'dgl_graph.bin')
+        if os.path.exists(graph_path):
+            return True
+        return False
+    def save(self):
+        graph_path = os.path.join(self.save_path, 'dgl_graph.bin')
+        save_graphs(graph_path, self._graph)
+
+    def load(self):
+        graph_path = os.path.join(self.save_path, 'dgl_graph.bin')
+        graphs, _ = load_graphs(graph_path)
+        self._graph = graphs[0]
+        self._print_info()
+
+    def _print_info(self):
+        if self.verbose:
+            print('Finished data loading.')
+            print(' NumNodes: {}'.format(self._graph.number_of_nodes()))
+            print(' NumEdges: {}'.format(self._graph.number_of_edges()))
+            print(' NumFeats: {}'.format(self._graph.ndata['feat'].shape[1]))
+            print(' NumClasses: {}'.format(self.num_classes))
+            print(' NumTrainingSamples: {}'.format(F.nonzero_1d(self._graph.ndata['train_mask']).shape[0]))
+            print(' NumValidationSamples: {}'.format(F.nonzero_1d(self._graph.ndata['val_mask']).shape[0]))
+            print(' NumTestSamples: {}'.format(F.nonzero_1d(self._graph.ndata['test_mask']).shape[0]))
+    @property
+    def num_classes(self):
+        r"""Number of classes for each node."""
+        return 41
+
+    @property
+    def num_labels(self):
+        deprecate_property('dataset.num_labels', 'dataset.num_classes')
+        return self.num_classes
+
+    @property
+    def graph(self):
+        deprecate_property('dataset.graph', 'dataset[0]')
+        return self._graph
+
+    @property
+    def train_mask(self):
+        deprecate_property('dataset.train_mask', 'graph.ndata[\'train_mask\']')
+        return self._graph.ndata['train_mask']
+
+    @property
+    def val_mask(self):
+        deprecate_property('dataset.val_mask', 'graph.ndata[\'val_mask\']')
+        return self._graph.ndata['val_mask']
+
+    @property
+    def test_mask(self):
+        deprecate_property('dataset.test_mask', 'graph.ndata[\'test_mask\']')
+        return self._graph.ndata['test_mask']
+
+    @property
+    def features(self):
+        deprecate_property('dataset.features', 'graph.ndata[\'feat\']')
+        return self._graph.ndata['feat']
+
+    @property
+    def labels(self):
+        deprecate_property('dataset.labels', 'graph.ndata[\'label\']')
+        return self._graph.ndata['label']
     def __getitem__(self, idx):
+        r""" Get graph by index
+
+        Parameters
+        ----------
+        idx : int
+            Item index
+
+        Returns
+        -------
+        dgl.DGLGraph
+            graph structure, node labels, node features and splitting masks:
+
+            - ndata['label']: node label
+            - ndata['feat']: node feature
+            - ndata['train_mask']: mask for training node set
+            - ndata['val_mask']: mask for validation node set
+            - ndata['test_mask']: mask for test node set
+        """
         assert idx == 0, "Reddit Dataset only has one graph"
-        self.graph.ndata['train_mask'] = F.tensor(self.train_mask, dtype=F.bool)
-        self.graph.ndata['val_mask'] = F.tensor(self.val_mask, dtype=F.bool)
-        self.graph.ndata['test_mask'] = F.tensor(self.test_mask, dtype=F.bool)
-        self.graph.ndata['feat'] = F.tensor(self.features, dtype=F.float32)
-        self.graph.ndata['label'] = F.tensor(self.labels, dtype=F.int64)
-        return self.graph
+        return self._graph

     def __len__(self):
+        r"""Number of graphs in the dataset"""
         return 1
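A minimal usage sketch of the reworked dataset class (assuming DGL >= 0.5 with a PyTorch backend; the first construction downloads and processes the raw data, after which has_cache()/load() serve the cached dgl_graph.bin instead of re-running process()):

    from dgl.data import RedditDataset

    data = RedditDataset(self_loop=True, verbose=True)
    g = data[0]                              # the single Reddit graph
    print(len(data))                         # 1
    print(data.num_classes)                  # 41
    print(g.ndata['feat'].shape)             # (232965, 602)
    print(int(g.ndata['train_mask'].sum()))  # 153431 training nodes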