Unverified commit 186ef592 authored by Rhett Ying, committed by GitHub

[Feature] apply dgl.reorder() onto several node classification datasets in DGL (#3102)

* [Feature] apply dgl.reorder() onto several node classification datasets in DGL

* rebase on latest dgl.reorder_graph()
parent 175f53de
@@ -133,6 +133,13 @@ a single graph. As such, splits of the dataset are on the nodes of the
graph. DGL recommends using node masks to specify the splits. The section uses
builtin dataset `CitationGraphDataset <https://docs.dgl.ai/en/0.5.x/_modules/dgl/data/citation_graph.html#CitationGraphDataset>`__ as an example:
In addition, DGL recommends rearranging the nodes and edges so that nodes
close to each other are assigned IDs within a narrow range. This can improve
the locality of accessing a node's neighbors, which may benefit subsequent
computation and analysis on the graph. DGL provides an API called
:func:`dgl.reorder_graph` for this purpose. Please refer to the ``process()``
part of the example below for more details.
.. code::

    from dgl.data import DGLBuiltinDataset

@@ -173,7 +180,8 @@ builtin dataset `CitationGraphDataset <https://docs.dgl.ai/en/0.5.x/_modules/dgl

                dtype=F.data_type_dict['float32'])
            self._num_labels = onehot_labels.shape[1]
            self._labels = labels
            # reorder graph to obtain better locality.
            self._g = dgl.reorder_graph(g)

        def __getitem__(self, idx):
            assert idx == 0, "This dataset has only one graph"
...
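For readers new to the API, here is a minimal, self-contained sketch (illustrative only, not part of this diff) of :func:`dgl.reorder_graph` with the same settings the updated datasets use; the toy edge lists are made up for the example:

.. code::

    import dgl
    import torch

    # A tiny graph; node and edge IDs follow insertion order.
    g = dgl.graph((torch.tensor([0, 2, 3, 1]), torch.tensor([2, 3, 1, 0])))

    # Permute nodes with the reverse Cuthill-McKee heuristic and sort
    # edges by destination ID, mirroring the dataset changes below.
    rg = dgl.reorder_graph(g, node_permute_algo='rcmk',
                           edge_permute_algo='dst', store_ids=False)

    # After reordering, edge destination IDs are non-decreasing.
    print(rg.edges()[1])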
@@ -116,6 +116,11 @@ DGL recommends having ``__getitem__(idx)`` return a tuple ``(graph,
DGL recommends using node masks to specify the splits of the dataset.
This section uses the builtin dataset `CitationGraphDataset <https://docs.dgl.ai/en/0.5.x/_modules/dgl/data/citation_graph.html#CitationGraphDataset>`__ as an example:
In addition, DGL recommends rearranging the graph's nodes/edges so that the IDs
of nearby nodes/edges fall within a close range. This can improve the locality
of accessing a node's/edge's neighbors and may benefit the performance of
subsequent computation and analysis on the graph. DGL provides an API named
:func:`dgl.reorder_graph` for this optimization. For more details, please refer
to the ``process()`` part of the example below.
.. code::

    from dgl.data import DGLBuiltinDataset

@@ -159,7 +164,8 @@ DGL recommends using node masks to specify the splits of the dataset.

                dtype=F.data_type_dict['float32'])
            self._num_labels = onehot_labels.shape[1]
            self._labels = labels
            # reorder the graph for better locality
            self._g = dgl.reorder_graph(g)

        def __getitem__(self, idx):
            assert idx == 0, "This dataset has only one graph"
...
@@ -20,6 +20,7 @@ from .. import batch
from .. import backend as F
from ..convert import graph as dgl_graph
from ..convert import from_networkx, to_networkx
from ..transform import reorder_graph

backend = os.environ.get('DGLBACKEND', 'pytorch')
@@ -71,7 +72,7 @@ class CitationGraphDataset(DGLBuiltinDataset):
                         verbose=verbose)

    def process(self):
        """Loads input data from the data directory and reorders the graph for better locality

        ind.name.x => the feature vectors of the training instances as scipy.sparse.csr.csr_matrix object;
        ind.name.tx => the feature vectors of the test instances as scipy.sparse.csr.csr_matrix object;
@@ -136,7 +137,8 @@ class CitationGraphDataset(DGLBuiltinDataset):
        g.ndata['feat'] = F.tensor(_preprocess_features(features), dtype=F.data_type_dict['float32'])
        self._num_classes = onehot_labels.shape[1]
        self._labels = labels
        self._g = reorder_graph(
            g, node_permute_algo='rcmk', edge_permute_algo='dst', store_ids=False)

        if self.verbose:
            print('Finished data loading and preprocessing.')
...
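One detail worth noting is ``store_ids=False``. To my understanding of the ``reorder_graph`` API, the default ``store_ids=True`` attaches the pre-permutation node/edge IDs as extra features under ``dgl.NID``/``dgl.EID``, which these datasets do not need to carry around. A small sketch of that default behavior, assuming standard DGL semantics:

.. code::

    import dgl
    import torch

    g = dgl.graph((torch.tensor([0, 2, 1]), torch.tensor([1, 0, 2])))

    # With the default store_ids=True, the original IDs survive the
    # permutation as node/edge features, useful for mapping results back.
    rg = dgl.reorder_graph(g, node_permute_algo='rcmk')
    print(rg.ndata[dgl.NID], rg.edata[dgl.EID])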
@@ -38,6 +38,8 @@ class GNNBenchmarkDataset(DGLBuiltinDataset):
    def process(self):
        npz_path = os.path.join(self.raw_path, self.name + '.npz')
        g = self._load_npz(npz_path)
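        # reorder with RCMK node permutation and destination-sorted edges for better locality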
        g = transform.reorder_graph(
            g, node_permute_algo='rcmk', edge_permute_algo='dst', store_ids=False)
        self._graph = g
        self._data = [g]
        self._print_info()
...
@@ -9,6 +9,7 @@ from .dgl_dataset import DGLBuiltinDataset
from .utils import _get_dgl_url, generate_mask_tensor, load_graphs, save_graphs, deprecate_property
from .. import backend as F
from ..convert import from_scipy
from ..transform import reorder_graph

class RedditDataset(DGLBuiltinDataset):

@@ -155,6 +156,9 @@ class RedditDataset(DGLBuiltinDataset):
        self._graph.ndata['test_mask'] = generate_mask_tensor(test_mask)
        self._graph.ndata['feat'] = F.tensor(features, dtype=F.data_type_dict['float32'])
        self._graph.ndata['label'] = F.tensor(labels, dtype=F.data_type_dict['int64'])
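        # reorder the Reddit graph as well: RCMK nodes, destination-sorted edges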
        self._graph = reorder_graph(
            self._graph, node_permute_algo='rcmk', edge_permute_algo='dst', store_ids=False)
        self._print_info()

    def has_cache(self):
...
import dgl.data as data
import unittest
import backend as F
import numpy as np

@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")

@@ -66,6 +67,78 @@ def test_data_hash():
    assert a.hash != c.hash
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_citation_graph():
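    # The datasets now reorder their graphs with edge_permute_algo='dst',
    # so edge destination IDs should come out sorted (verified per dataset below).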
    # cora
    g = data.CoraGraphDataset()[0]
    assert g.num_nodes() == 2708
    assert g.num_edges() == 10556
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))

    # Citeseer
    g = data.CiteseerGraphDataset()[0]
    assert g.num_nodes() == 3327
    assert g.num_edges() == 9228
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))

    # Pubmed
    g = data.PubmedGraphDataset()[0]
    assert g.num_nodes() == 19717
    assert g.num_edges() == 88651
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))

@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_gnn_benchmark():
    # AmazonCoBuyComputerDataset
    g = data.AmazonCoBuyComputerDataset()[0]
    assert g.num_nodes() == 13752
    assert g.num_edges() == 491722
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))

    # AmazonCoBuyPhotoDataset
    g = data.AmazonCoBuyPhotoDataset()[0]
    assert g.num_nodes() == 7650
    assert g.num_edges() == 238163
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))

    # CoauthorPhysicsDataset
    g = data.CoauthorPhysicsDataset()[0]
    assert g.num_nodes() == 34493
    assert g.num_edges() == 495924
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))

    # CoauthorCSDataset
    g = data.CoauthorCSDataset()[0]
    assert g.num_nodes() == 18333
    assert g.num_edges() == 163788
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))

    # CoraFullDataset
    g = data.CoraFullDataset()[0]
    assert g.num_nodes() == 19793
    assert g.num_edges() == 126842
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))

@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_reddit():
    # RedditDataset
    g = data.RedditDataset()[0]
    assert g.num_nodes() == 232965
    assert g.num_edges() == 114615892
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))
if __name__ == '__main__':
    test_minigc()
    test_gin()
...