Unverified Commit 37aa99c5 authored by Xiangkun Hu's avatar Xiangkun Hu Committed by GitHub
Browse files

[Dataset] GNNBenchmarkDataset (#1912)



* PPIDataset

* Revert "PPIDataset"

This reverts commit 264bd0c960cfa698a7bb946dad132bf52c2d0c8a.

* gnn benchmark dataset

* Update gnn_benckmark.py
Co-authored-by: default avatarxiang song(charlie.song) <classicxsong@gmail.com>
parent 56339b34
......@@ -10,7 +10,8 @@ from .sbm import SBMMixture
from .reddit import RedditDataset
from .ppi import PPIDataset, LegacyPPIDataset
from .tu import TUDataset, LegacyTUDataset
from .gnn_benckmark import AmazonCoBuy, CoraFull, Coauthor
from .gnn_benckmark import AmazonCoBuy, CoraFull, Coauthor, AmazonCoBuyComputerDataset, \
AmazonCoBuyPhotoDataset, CoauthorPhysicsDataset, CoauthorCSDataset, CoraFullDataset
from .karate import KarateClub, KarateClubDataset
from .gindt import GINDataset
from .bitcoinotc import BitcoinOTC, BitcoinOTCDataset
......
"""GNN Benchmark datasets for node classification."""
import scipy.sparse as sp
import numpy as np
import os
from .utils import download, extract_archive, get_download_dir, _get_dgl_url
from ..utils import retry_method_with_fix
from .. import convert
__all__=["AmazonCoBuy", "Coauthor", 'CoraFull']
from .dgl_dataset import DGLBuiltinDataset
from .utils import save_graphs, load_graphs, _get_dgl_url, deprecate_property, deprecate_class
from ..convert import graph as dgl_graph
from .. import backend as F
__all__ = ["AmazonCoBuyComputerDataset", "AmazonCoBuyPhotoDataset", "CoauthorPhysicsDataset", "CoauthorCSDataset",
"CoraFullDataset", "AmazonCoBuy", "Coauthor", "CoraFull"]
def eliminate_self_loops(A):
"""Remove self-loops from the adjacency matrix."""
......@@ -16,24 +21,51 @@ def eliminate_self_loops(A):
return A
class GNNBenchmarkDataset(object):
"""Base Class for GNN Benchmark dataset from https://github.com/shchur/gnn-benchmark#datasets"""
_url = {}
class GNNBenchmarkDataset(DGLBuiltinDataset):
r"""Base Class for GNN Benchmark dataset
Reference: https://github.com/shchur/gnn-benchmark#datasets
"""
def __init__(self, name, raw_dir=None, force_reload=False, verbose=False):
_url = _get_dgl_url('dataset/' + name + '.zip')
super(GNNBenchmarkDataset, self).__init__(name=name,
url=_url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
def __init__(self, name):
assert name.lower() in self._url, "Name not valid"
self.dir = get_download_dir()
self.path = os.path.join(
self.dir, 'gnn_benckmark', self._url[name.lower()].split('/')[-1])
self._name = name
g = self.load_npz(self.path)
self.data = [g]
def process(self):
npz_path = os.path.join(self.raw_path, self.name + '.npz')
g = self._load_npz(npz_path)
self._graph = g
self._data = [g]
self._print_info()
def _download(self):
download(self._url[self._name.lower()], path=self.path)
def has_cache(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin')
if os.path.exists(graph_path):
return True
return False
@retry_method_with_fix(_download)
def load_npz(self, file_name):
def save(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin')
save_graphs(graph_path, self._graph)
def load(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin')
graphs, _ = load_graphs(graph_path)
self._graph = graphs[0]
self._data = [graphs[0]]
self._print_info()
def _print_info(self):
if self.verbose:
print(' NumNodes: {}'.format(self._graph.number_of_nodes()))
print(' NumEdges: {}'.format(self._graph.number_of_edges()))
print(' NumFeats: {}'.format(self._graph.ndata['feat'].shape[-1]))
print(' NumbClasses: {}'.format(self.num_classes))
def _load_npz(self, file_name):
with np.load(file_name, allow_pickle=True) as loader:
loader = dict(loader)
num_nodes = loader['adj_shape'][0]
......@@ -61,68 +93,359 @@ class GNNBenchmarkDataset(object):
labels = None
row = np.hstack([adj_matrix.row, adj_matrix.col])
col = np.hstack([adj_matrix.col, adj_matrix.row])
g = convert.graph((row, col))
g.ndata['feat'] = attr_matrix
g.ndata['label'] = labels
g = dgl_graph((row, col))
g.ndata['feat'] = F.tensor(attr_matrix, F.data_type_dict['float32'])
g.ndata['label'] = F.tensor(labels, F.data_type_dict['int64'])
return g
@property
def num_classes(self):
"""Number of classes."""
raise NotImplementedError
@property
def data(self):
deprecate_property('dataset.data', 'dataset[0]')
return self._data
def __getitem__(self, idx):
r""" Get graph by index
Parameters
----------
idx : int
Item index
Returns
-------
dgl.DGLGraph
graph structure, node features and node labels
- ndata['feat']: node features
- ndata['label']: node labels
"""
assert idx == 0, "This dataset has only one graph"
return self.data[0]
return self._graph
def __len__(self):
return len(self.data)
r"""Number of graphs in the dataset"""
return 1
class CoraFull(GNNBenchmarkDataset):
r"""
class CoraFullDataset(GNNBenchmarkDataset):
r"""CORA-Full dataset for node classification task.
.. deprecated:: 0.5.0
`data` is deprecated, it is repalced by:
>>> dataset = CoraFullDataset()
>>> graph = dataset[0]
Extended Cora dataset from `Deep Gaussian Embedding of Graphs:
Unsupervised Inductive Learning via Ranking`. Nodes represent paper and edges represent citations.
Unsupervised Inductive Learning via Ranking`.
Nodes represent paper and edges represent citations.
Reference: https://github.com/shchur/gnn-benchmark#datasets
Statistics
----------
Nodes: 19,793
Edges: 130,622
Number of Classes: 70
Node feature size: 8,710
Parameters
----------
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
Attributes
----------
num_classes : int
Number of classes for each node.
data : list
A list of DGLGraph objects
Examples
--------
>>> data = CoraFullDataset()
>>> g = data[0]
>>> num_class = data.num_classes
>>> feat = g.ndata['feat'] # get node feature
>>> label = g.ndata['label'] # get node labels
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=False):
super(CoraFullDataset, self).__init__(name="cora_full",
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
@property
def num_classes(self):
"""Number of classes."""
return 70
class CoauthorCSDataset(GNNBenchmarkDataset):
r""" 'Computer Science (CS)' part of the Coauthor dataset for node classification task.
.. deprecated:: 0.5.0
`data` is deprecated, it is repalced by:
>>> dataset = CoauthorCSDataset()
>>> graph = dataset[0]
Coauthor CS and Coauthor Physics are co-authorship graphs based on the Microsoft Academic Graph
from the KDD Cup 2016 challenge. Here, nodes are authors, that are connected by an edge if they
co-authored a paper; node features represent paper keywords for each author’s papers, and class
labels indicate most active fields of study for each author.
Reference: https://github.com/shchur/gnn-benchmark#datasets
Statistics
----------
Nodes: 18,333
Edges: 327,576
Number of classes: 15
Node feature size: 6,805
Parameters
----------
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
Attributes
----------
num_classes : int
Number of classes for each node.
data : list
A list of DGLGraph objects
Examples
--------
>>> data = CoauthorCSDataset()
>>> g = data[0]
>>> num_class = data.num_classes
>>> feat = g.ndata['feat'] # get node feature
>>> label = g.ndata['label'] # get node labels
"""
_url = {"cora_full":'https://github.com/shchur/gnn-benchmark/raw/master/data/npz/cora_full.npz'}
def __init__(self, raw_dir=None, force_reload=False, verbose=False):
super(CoauthorCSDataset, self).__init__(name='coauthor_cs',
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
@property
def num_classes(self):
"""Number of classes."""
return 15
def __init__(self):
super().__init__("cora_full")
class CoauthorPhysicsDataset(GNNBenchmarkDataset):
r""" 'Physics' part of the Coauthor dataset for node classification task.
.. deprecated:: 0.5.0
`data` is deprecated, it is repalced by:
>>> dataset = CoauthorPhysicsDataset()
>>> graph = dataset[0]
class Coauthor(GNNBenchmarkDataset):
r"""
Coauthor CS and Coauthor Physics are co-authorship graphs based on the Microsoft Academic Graph
from the KDD Cup 2016 challenge 3
. Here, nodes are authors, that are connected by an edge if they
from the KDD Cup 2016 challenge. Here, nodes are authors, that are connected by an edge if they
co-authored a paper; node features represent paper keywords for each author’s papers, and class
labels indicate most active fields of study for each author.
Reference: https://github.com/shchur/gnn-benchmark#datasets
Statistics
----------
Nodes: 34,493
Edges: 991,848
Number of classes: 5
Node feature size: 8,415
Parameters
---------------
name: str
Name of the dataset, has to be 'cs' or 'physics'
----------
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
Attributes
----------
num_classes : int
Number of classes for each node.
data : list
A list of DGLGraph objects
Examples
--------
>>> data = CoauthorPhysicsDataset()
>>> g = data[0]
>>> num_class = data.num_classes
>>> feat = g.ndata['feat'] # get node feature
>>> label = g.ndata['label'] # get node labels
"""
_url = {
'cs': "https://github.com/shchur/gnn-benchmark/raw/master/data/npz/ms_academic_cs.npz",
'physics': "https://github.com/shchur/gnn-benchmark/raw/master/data/npz/ms_academic_phy.npz"
}
def __init__(self, raw_dir=None, force_reload=False, verbose=False):
super(CoauthorPhysicsDataset, self).__init__(name='coauthor_physics',
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
@property
def num_classes(self):
"""Number of classes."""
return 5
class AmazonCoBuy(GNNBenchmarkDataset):
r"""
Amazon Computers and Amazon Photo are segments of the Amazon co-purchase graph [McAuley
et al., 2015], where nodes represent goods, edges indicate that two goods are frequently bought
together, node features are bag-of-words encoded product reviews, and class labels are given by the
product category.
class AmazonCoBuyComputerDataset(GNNBenchmarkDataset):
r""" 'Computer' part of the AmazonCoBuy dataset for node classification task.
.. deprecated:: 0.5.0
`data` is deprecated, it is repalced by:
>>> dataset = AmazonCoBuyComputerDataset()
>>> graph = dataset[0]
Amazon Computers and Amazon Photo are segments of the Amazon co-purchase graph [McAuley et al., 2015],
where nodes represent goods, edges indicate that two goods are frequently bought together, node
features are bag-of-words encoded product reviews, and class labels are given by the product category.
Reference: https://github.com/shchur/gnn-benchmark#datasets
Statistics
----------
Nodes: 13,752
Edges: 574,418
Number of classes: 5
Node feature size: 767
Parameters
---------------
name: str
Name of the dataset, has to be 'computers' or 'photo'
----------
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
Attributes
----------
num_classes : int
Number of classes for each node.
data : list
A list of DGLGraph objects
Examples
--------
>>> data = AmazonCoBuyComputerDataset()
>>> g = data[0]
>>> num_class = data.num_classes
>>> feat = g.ndata['feat'] # get node feature
>>> label = g.ndata['label'] # get node labels
"""
_url = {
'computers': "https://github.com/shchur/gnn-benchmark/raw/master/data/npz/amazon_electronics_computers.npz",
'photo': "https://github.com/shchur/gnn-benchmark/raw/master/data/npz/amazon_electronics_photo.npz"
}
def __init__(self, raw_dir=None, force_reload=False, verbose=False):
super(AmazonCoBuyComputerDataset, self).__init__(name='amazon_co_buy_computer',
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
@property
def num_classes(self):
"""Number of classes."""
return 5
class AmazonCoBuyPhotoDataset(GNNBenchmarkDataset):
r"""AmazonCoBuy dataset for node classification task.
.. deprecated:: 0.5.0
`data` is deprecated, it is repalced by:
>>> dataset = AmazonCoBuyPhotoDataset()
>>> graph = dataset[0]
Amazon Computers and Amazon Photo are segments of the Amazon co-purchase graph [McAuley et al., 2015],
where nodes represent goods, edges indicate that two goods are frequently bought together, node
features are bag-of-words encoded product reviews, and class labels are given by the product category.
Reference: https://github.com/shchur/gnn-benchmark#datasets
Statistics
----------
Nodes: 7,650
Edges: 287,326
Number of classes: 5
Node feature size: 745
Parameters
----------
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
Attributes
----------
num_classes : int
Number of classes for each node.
data : list
A list of DGLGraph objects
Examples
--------
>>> data = AmazonCoBuyPhotoDataset()
>>> g = data[0]
>>> num_class = data.num_classes
>>> feat = g.ndata['feat'] # get node feature
>>> label = g.ndata['label'] # get node labels
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=False):
super(AmazonCoBuyPhotoDataset, self).__init__(name='amazon_co_buy_photo',
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
@property
def num_classes(self):
"""Number of classes."""
return 5
class CoraFull(CoraFullDataset):
def __init__(self, **kwargs):
deprecate_class('CoraFull', 'CoraFullDataset')
super(CoraFull, self).__init__(**kwargs)
def AmazonCoBuy(name):
if name == 'computers':
deprecate_class('AmazonCoBuy', 'AmazonCoBuyComputerDataset')
return AmazonCoBuyComputerDataset()
elif name == 'photo':
deprecate_class('AmazonCoBuy', 'AmazonCoBuyPhotoDataset')
return AmazonCoBuyPhotoDataset()
else:
raise ValueError('Dataset name should be "computers" or "photo".')
def Coauthor(name):
if name == 'cs':
deprecate_class('Coauthor', 'CoauthorCSDataset')
return CoauthorCSDataset()
elif name == 'physics':
deprecate_class('Coauthor', 'CoauthorPhysicsDataset')
return CoauthorPhysicsDataset()
else:
raise ValueError('Dataset name should be "cs" or "physics".')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment