Commit 92e77330 (unverified), authored by Huarui HE, committed by GitHub
Browse files

[dataset] Add a `reorder` flag to builtin datasets (#4104)



* add argument reorder=False for citation_graph

* add description of the argument reorder

* add reordered/un_reordered save_path

* add version number postfix
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
parent 148575e4
......@@ -51,6 +51,8 @@ class CitationGraphDataset(DGLBuiltinDataset):
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
reorder : bool
Whether to reorder the graph using :func:`~dgl.reorder_graph`. Default: False.
"""
_urls = {
'cora_v2' : 'dataset/cora_v2.zip',
......@@ -59,7 +61,8 @@ class CitationGraphDataset(DGLBuiltinDataset):
}
def __init__(self, name, raw_dir=None, force_reload=False,
verbose=True, reverse_edge=True, transform=None):
verbose=True, reverse_edge=True, transform=None,
reorder=False):
assert name.lower() in ['cora', 'citeseer', 'pubmed']
# Previously we use the pre-processing in pygcn (https://github.com/tkipf/pygcn)
......@@ -69,6 +72,7 @@ class CitationGraphDataset(DGLBuiltinDataset):
url = _get_dgl_url(self._urls[name])
self._reverse_edge = reverse_edge
self._reorder = reorder
super(CitationGraphDataset, self).__init__(name,
url=url,
......@@ -143,8 +147,11 @@ class CitationGraphDataset(DGLBuiltinDataset):
g.ndata['feat'] = F.tensor(_preprocess_features(features), dtype=F.data_type_dict['float32'])
self._num_classes = onehot_labels.shape[1]
self._labels = labels
self._g = reorder_graph(
g, node_permute_algo='rcmk', edge_permute_algo='dst', store_ids=False)
if self._reorder:
self._g = reorder_graph(
g, node_permute_algo='rcmk', edge_permute_algo='dst', store_ids=False)
else:
self._g = g
if self.verbose:
print('Finished data loading and preprocessing.')
......@@ -373,6 +380,8 @@ class CoraGraphDataset(CitationGraphDataset):
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
reorder : bool
Whether to reorder the graph using :func:`~dgl.reorder_graph`. Default: False.
Attributes
----------
......@@ -414,11 +423,11 @@ class CoraGraphDataset(CitationGraphDataset):
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=True,
reverse_edge=True, transform=None):
reverse_edge=True, transform=None, reorder=False):
name = 'cora'
super(CoraGraphDataset, self).__init__(name, raw_dir, force_reload,
verbose, reverse_edge, transform)
verbose, reverse_edge, transform, reorder)
def __getitem__(self, idx):
r"""Gets the graph object
......@@ -519,6 +528,8 @@ class CiteseerGraphDataset(CitationGraphDataset):
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
reorder : bool
Whether to reorder the graph using :func:`~dgl.reorder_graph`. Default: False.
Attributes
----------
......@@ -563,11 +574,11 @@ class CiteseerGraphDataset(CitationGraphDataset):
"""
def __init__(self, raw_dir=None, force_reload=False,
verbose=True, reverse_edge=True, transform=None):
verbose=True, reverse_edge=True, transform=None, reorder=False):
name = 'citeseer'
super(CiteseerGraphDataset, self).__init__(name, raw_dir, force_reload,
verbose, reverse_edge, transform)
verbose, reverse_edge, transform, reorder)
def __getitem__(self, idx):
r"""Gets the graph object
......@@ -668,6 +679,8 @@ class PubmedGraphDataset(CitationGraphDataset):
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
reorder : bool
Whether to reorder the graph using :func:`~dgl.reorder_graph`. Default: False.
Attributes
----------
......@@ -709,11 +722,11 @@ class PubmedGraphDataset(CitationGraphDataset):
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=True,
reverse_edge=True, transform=None):
reverse_edge=True, transform=None, reorder=False):
name = 'pubmed'
super(PubmedGraphDataset, self).__init__(name, raw_dir, force_reload,
verbose, reverse_edge, transform)
verbose, reverse_edge, transform, reorder)
def __getitem__(self, idx):
r"""Gets the graph object
......
......@@ -8,6 +8,7 @@ import traceback
import abc
from .utils import download, extract_archive, get_download_dir, makedirs
from ..utils import retry_method_with_fix
from .._ffi.base import __version__
class DGLDataset(object):
r"""The basic DGL dataset for creating graph datasets.
......@@ -237,13 +238,17 @@ class DGLDataset(object):
def save_dir(self):
r"""Directory to save the processed dataset.
"""
return self._save_dir
return self._save_dir + "_v{}".format(__version__)
@property
def save_path(self):
r"""Path to save the processed dataset.
"""
return os.path.join(self._save_dir, self.name)
if hasattr(self, '_reorder'):
path = 'reordered' if self._reorder else 'un_reordered'
return os.path.join(self._save_dir, self.name, path)
else:
return os.path.join(self._save_dir, self.name)
@property
def verbose(self):
......
......@@ -17,6 +17,7 @@ from .graph_serialize import save_graphs, load_graphs, load_labels
from .tensor_serialize import save_tensors, load_tensors
from .. import backend as F
from .._ffi.base import __version__
__all__ = ['loadtxt','download', 'check_sha1', 'extract_archive',
'get_download_dir', 'Subset', 'split_dataset', 'save_graphs',
......@@ -240,7 +241,7 @@ def get_download_dir():
dirname : str
Path to the download directory
"""
default_dir = os.path.join(os.path.expanduser('~'), '.dgl')
default_dir = os.path.join(os.path.expanduser('~'), '.dgl_v{}'.format(__version__))
dirname = os.environ.get('DGL_DOWNLOAD_DIR', default_dir)
if not os.path.exists(dirname):
os.makedirs(dirname)
......
......@@ -115,7 +115,7 @@ def test_citation_graph():
transform = dgl.AddSelfLoop(allow_duplicate=True)
# cora
g = data.CoraGraphDataset()[0]
g = data.CoraGraphDataset(force_reload=True, reorder=True)[0]
assert g.num_nodes() == 2708
assert g.num_edges() == 10556
dst = F.asnumpy(g.edges()[1])
......@@ -124,7 +124,7 @@ def test_citation_graph():
assert g2.num_edges() - g.num_edges() == g.num_nodes()
# Citeseer
g = data.CiteseerGraphDataset()[0]
g = data.CiteseerGraphDataset(force_reload=True, reorder=True)[0]
assert g.num_nodes() == 3327
assert g.num_edges() == 9228
dst = F.asnumpy(g.edges()[1])
......@@ -133,7 +133,7 @@ def test_citation_graph():
assert g2.num_edges() - g.num_edges() == g.num_nodes()
# Pubmed
g = data.PubmedGraphDataset()[0]
g = data.PubmedGraphDataset(force_reload=True, reorder=True)[0]
assert g.num_nodes() == 19717
assert g.num_edges() == 88651
dst = F.asnumpy(g.edges()[1])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment