Commit 51293a32 authored by Ziyue Huang's avatar Ziyue Huang Committed by Da Zheng
Browse files

[Data] reddit data loader (#372)

* reddit data loader

* upload to S3

* update

* add self loop

* address comments

* fix
parent 5b9147c4
...@@ -7,6 +7,7 @@ from .minigc import * ...@@ -7,6 +7,7 @@ from .minigc import *
from .tree import * from .tree import *
from .utils import * from .utils import *
from .sbm import SBMMixture from .sbm import SBMMixture
from .reddit import RedditDataset
def register_data_args(parser): def register_data_args(parser):
parser.add_argument("--dataset", type=str, required=False, parser.add_argument("--dataset", type=str, required=False,
...@@ -22,5 +23,7 @@ def load_data(args): ...@@ -22,5 +23,7 @@ def load_data(args):
return citegrh.load_pubmed() return citegrh.load_pubmed()
elif args.dataset == 'syn': elif args.dataset == 'syn':
return citegrh.load_synthetic(args) return citegrh.load_synthetic(args)
elif args.dataset.startswith('reddit'):
return RedditDataset(self_loop=('self-loop' in args.dataset))
else: else:
raise ValueError('Unknown dataset: {}'.format(args.dataset)) raise ValueError('Unknown dataset: {}'.format(args.dataset))
from __future__ import absolute_import
import scipy.sparse as sp
import numpy as np
import dgl
import os, sys
from ..graph_index import create_graph_index
from .utils import download, extract_archive, get_download_dir, _get_dgl_url
class RedditDataset(object):
    """Reddit post-to-post interaction graph dataset.

    Downloads a preprocessed Reddit graph archive from the DGL data
    bucket (optionally the variant with self loops added), extracts it,
    and loads:

    - ``self.graph``: read-only graph index built from the sparse
      adjacency matrix.
    - ``self.features`` / ``self.labels``: per-node feature matrix and
      class labels.
    - ``self.num_labels``: number of classes (41 subreddits).
    - ``self.train_mask`` / ``self.val_mask`` / ``self.test_mask``:
      boolean split masks derived from per-node type codes.

    Parameters
    ----------
    self_loop : bool
        If True, download the graph variant that includes self loops.
    """
    def __init__(self, self_loop=False):
        download_dir = get_download_dir()
        self_loop_str = ""
        if self_loop:
            # The self-loop variant is shipped as a separate archive.
            self_loop_str = "_self_loop"
        zip_file_path = os.path.join(download_dir, "reddit{}.zip".format(self_loop_str))
        download(_get_dgl_url("dataset/reddit{}.zip".format(self_loop_str)), path=zip_file_path)
        extract_dir = os.path.join(download_dir, "reddit{}".format(self_loop_str))
        extract_archive(zip_file_path, extract_dir)
        # graph: sparse COO adjacency -> immutable graph index
        coo_adj = sp.load_npz(os.path.join(extract_dir, "reddit{}_graph.npz".format(self_loop_str)))
        self.graph = create_graph_index(coo_adj, readonly=True)
        # features and labels
        reddit_data = np.load(os.path.join(extract_dir, "reddit_data.npz"))
        self.features = reddit_data["feature"]
        self.labels = reddit_data["label"]
        self.num_labels = 41  # number of subreddit classes
        # train/val/test masks: node_types encodes the split
        # (presumably 1 = train, 2 = validation, 3 = test — matches the
        # comparisons below; TODO confirm against the dataset docs)
        node_types = reddit_data["node_types"]
        self.train_mask = (node_types == 1)
        self.val_mask = (node_types == 2)
        self.test_mask = (node_types == 3)
        print('Finished data loading.')
        print('  NumNodes: {}'.format(self.graph.number_of_nodes()))
        print('  NumEdges: {}'.format(self.graph.number_of_edges()))
        print('  NumFeats: {}'.format(self.features.shape[1]))
        print('  NumClasses: {}'.format(self.num_labels))
        print('  NumTrainingSamples: {}'.format(len(np.nonzero(self.train_mask)[0])))
        print('  NumValidationSamples: {}'.format(len(np.nonzero(self.val_mask)[0])))
        print('  NumTestSamples: {}'.format(len(np.nonzero(self.test_mask)[0])))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment