Unverified Commit f05bd497 authored by Tong He, committed by GitHub

[Dataset] Builtin GINDataset (#1889)



* update GINDataset

* update

* update docstring

* update

* update docstrings

* add hash to fix save/load

* fix some stupid bugs
Co-authored-by: xiang song (charlie.song) <classicxsong@gmail.com>
parent fb02aa2d
@@ -55,7 +55,7 @@ class DGLDataset(object):
self._url = url
self._force_reload = force_reload
self._verbose = verbose
self._hash_key = hask_key
self._hash_key = hash_key
self._hash = self._get_hash()
# if no dir is provided, the default dgl download dir is used.
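# A minimal sketch of how _get_hash could turn hash_key into the short
# string used to tag cached files (hypothetical; the actual helper in
# DGLDataset may differ):
#
#     import hashlib
#     def _get_hash(self):
#         hash_func = hashlib.sha1()
#         hash_func.update(str(self._hash_key).encode('utf-8'))
#         return hash_func.hexdigest()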
......
"""Dataset for Graph Isomorphism Network(GIN)
(chen jun): Used for compacted graph kernel dataset in GIN
Data sets include:
MUTAG, COLLAB, IMDBBINARY, IMDBMULTI, NCI1, PROTEINS, PTC, REDDITBINARY, REDDITMULTI5K
https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip
"""
@@ -12,23 +10,19 @@ import numpy as np
from .. import backend as F
from .utils import download, extract_archive, get_download_dir, _get_dgl_url
from .dgl_dataset import DGLBuiltinDataset
from .utils import loadtxt, save_graphs, load_graphs, save_info, load_info, download, extract_archive
from ..utils import retry_method_with_fix
from ..convert import graph
from .. import backend as F
from ..convert import graph as dgl_graph
_url = 'https://raw.githubusercontent.com/weihua916/powerful-gnns/master/dataset.zip'
class GINDataset(object):
class GINDataset(DGLBuiltinDataset):
"""Datasets for Graph Isomorphism Network (GIN)
Adapted from https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip.
The dataset contains the compact format of popular graph kernel datasets, including:
MUTAG, COLLAB, IMDBBINARY, IMDBMULTI, NCI1, PROTEINS, PTC, REDDITBINARY, REDDITMULTI5K
This dataset class processes all datasets listed above. For more graph kernel datasets,
see :class:`TUDataset`
see :class:`TUDataset`.
Parameters
----------
@@ -38,23 +32,46 @@ class GINDataset(object):
'IMDBBINARY', 'IMDBMULTI', \
'NCI1', 'PROTEINS', 'PTC', \
'REDDITBINARY', 'REDDITMULTI5K')
self_loop: boolean
self_loop: bool
add a self-loop to each node if True
degree_as_nlabel: boolean
degree_as_nlabel: bool
use the node degree as node label and feature if True
Examples
--------
>>> data = GINDataset(name='MUTAG', self_loop=False)
**The dataset instance is an iterable**
>>> len(data)
188
>>> g, label = data[128]
>>> g
Graph(num_nodes=13, num_edges=26,
ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(7,), dtype=torch.float64)}
edata_schemes={})
>>> label
tensor(1)
**Batch the graphs and labels for mini-batch training**
>>> graphs, labels = zip(*[data[i] for i in range(16)])
>>> batched_graphs = dgl.batch(graphs)
>>> batched_labels = torch.tensor(labels)
>>> batched_graphs
Graph(num_nodes=330, num_edges=748,
ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(7,), dtype=torch.float64)}
edata_schemes={})
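**Use a DataLoader for mini-batch training**

A sketch assuming PyTorch's ``DataLoader``; the ``collate`` function below is
illustrative and not part of this dataset class.

>>> from torch.utils.data import DataLoader
>>> def collate(samples):
...     graphs, labels = zip(*samples)
...     return dgl.batch(graphs), torch.tensor(labels)
>>> dataloader = DataLoader(data, batch_size=16, shuffle=True, collate_fn=collate)
>>> for batched_graph, batched_labels in dataloader:
...     pass  # training step goes here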
"""
def __init__(self, name, self_loop, degree_as_nlabel=False):
"""Initialize the dataset."""
def __init__(self, name, self_loop, degree_as_nlabel=False,
raw_dir=None, force_reload=False, verbose=False):
self.name = name # MUTAG
self._name = name # MUTAG
gin_url = 'https://raw.githubusercontent.com/weihua916/powerful-gnns/master/dataset.zip'
self.ds_name = 'nig'
self.extract_dir = self._get_extract_dir()
self.file = self._file_path()
self.self_loop = self_loop
self.graphs = []
self.labels = []
@@ -79,17 +96,27 @@ class GINDataset(object):
self.degree_as_nlabel = degree_as_nlabel
self.nattrs_flag = False
self.nlabels_flag = False
self.verbosity = False
# calc all values
self._load()
super(GINDataset, self).__init__(name=name, url=gin_url, hash_key=(name, self_loop, degree_as_nlabel),
raw_dir=raw_dir, force_reload=force_reload, verbose=verbose)
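# Note (not part of the diff above): passing hash_key=(name, self_loop,
# degree_as_nlabel) ties the on-disk cache written by save()/load() below
# to this exact configuration, so changing self_loop or degree_as_nlabel
# does not silently reuse a stale cache.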
@property
def raw_path(self):
return os.path.join(self.raw_dir, 'GINDataset')
def download(self):
r""" Automatically download data and extract it.
"""
zip_file_path = os.path.join(self.raw_dir, 'GINDataset.zip')
download(self.url, path=zip_file_path)
extract_archive(zip_file_path, self.raw_path)
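# With the default raw_dir (typically ~/.dgl), the archive is saved as
# ~/.dgl/GINDataset.zip and extracted under ~/.dgl/GINDataset/, so for
# name='MUTAG' the input file is expected at
# ~/.dgl/GINDataset/dataset/MUTAG/MUTAG.txt (see _file_path below).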
def __len__(self):
"""Return the number of graphs in the dataset."""
return len(self.graphs)
def __getitem__(self, idx):
"""Get the i^th sample.
"""Get the idx-th sample.
Parameters
----------
@@ -98,39 +125,26 @@ class GINDataset(object):
Returns
-------
(dgl.DGLGraph, int)
(dgl.Graph, int)
The graph and its label.
"""
return self.graphs[idx], self.labels[idx]
def _get_extract_dir(self):
return os.path.join(get_download_dir(), "{}".format(self.ds_name))
def _download(self):
download_dir = get_download_dir()
zip_file_path = os.path.join(
download_dir, "{}.zip".format(self.ds_name))
# TODO move to dgl host _get_dgl_url
download(_url, path=zip_file_path)
extract_dir = self._get_extract_dir()
extract_archive(zip_file_path, extract_dir)
def _file_path(self):
return os.path.join(self.extract_dir, "dataset", self.name, "{}.txt".format(self.name))
return os.path.join(self.raw_dir, "GINDataset", 'dataset', self.name, "{}.txt".format(self.name))
@retry_method_with_fix(_download)
def _load(self):
def process(self):
""" Loads input dataset from dataset/NAME/NAME.txt file
"""
if self.verbose:
print('loading data...')
self.file = self._file_path()
with open(self.file, 'r') as f:
# first line: N, the total number of graphs
self.N = int(f.readline().strip())
for i in range(self.N):
if (i + 1) % 10 == 0 and self.verbosity is True:
if (i + 1) % 10 == 0 and self.verbose is True:
print('processing graph {}...'.format(i + 1))
grow = f.readline().strip().split()
@@ -145,7 +159,7 @@ class GINDataset(object):
self.labels.append(self.glabel_dict[glabel])
g = graph([])
g = dgl_graph([])
g.add_nodes(n_nodes)
nlabels = [] # node labels
@@ -179,18 +193,18 @@ class GINDataset(object):
m_edges += nrow[1]
g.add_edges(j, nrow[2:])
if (j + 1) % 10 == 0 and self.verbosity is True:
# add self loop
if self.self_loop:
m_edges += 1
g.add_edges(j, j)
if (j + 1) % 10 == 0 and self.verbose is True:
print(
'processing node {} of graph {}...'.format(
j + 1, i + 1))
print('this node has {} edges.'.format(
nrow[1]))
# Add self loops
if self.self_loop:
m_edges += n_nodes
g.add_edges(F.arange(0, n_nodes), F.arange(0, n_nodes))
if nattrs != []:
nattrs = np.stack(nattrs)
g.ndata['attr'] = F.tensor(nattrs)
@@ -198,7 +212,7 @@ class GINDataset(object):
else:
nattrs = None
g.ndata['label'] = F.tensor(np.asarray(nlabels))
g.ndata['label'] = F.tensor(nlabels)
if len(self.nlabel_dict) > 1:
self.nlabels_flag = True
@@ -210,12 +224,15 @@ class GINDataset(object):
self.graphs.append(g)
self.labels = F.tensor(self.labels)
# if no attr
if not self.nattrs_flag:
if self.verbose:
print('there are no node features in this dataset!')
label2idx = {}
# generate node attr by node degree
if self.degree_as_nlabel:
if self.verbose:
print('generate node features by node degree...')
nlabel_set = set([])
for g in self.graphs:
@@ -235,6 +252,7 @@ class GINDataset(object):
label2idx = self.ndegree_dict
# generate node attr by node label
else:
if self.verbose:
print('generate node features by node label...')
label2idx = self.nlabel_dict
@@ -249,6 +267,7 @@ class GINDataset(object):
self.eclasses = len(self.elabel_dict)
self.dim_nfeats = len(self.graphs[0].ndata['attr'][0])
if self.verbose:
print('Done.')
print(
"""
@@ -269,3 +288,53 @@ class GINDataset(object):
self.dim_nfeats, self.m, self.eclasses,
self.n / self.N, self.m / self.N, self.glabel_dict,
self.nlabel_dict, self.ndegree_dict))
def save(self):
graph_path = os.path.join(self.save_path, 'gin_{}_{}.bin'.format(self.name, self.hash))
info_path = os.path.join(self.save_path, 'gin_{}_{}.pkl'.format(self.name, self.hash))
label_dict = {'labels': self.labels}
info_dict = {'N': self.N,
'n': self.n,
'm': self.m,
'self_loop': self.self_loop,
'gclasses': self.gclasses,
'nclasses': self.nclasses,
'eclasses': self.eclasses,
'dim_nfeats': self.dim_nfeats,
'degree_as_nlabel': self.degree_as_nlabel,
'glabel_dict': self.glabel_dict,
'nlabel_dict': self.nlabel_dict,
'elabel_dict': self.elabel_dict,
'ndegree_dict': self.ndegree_dict}
save_graphs(str(graph_path), self.graphs, label_dict)
save_info(str(info_path), info_dict)
def load(self):
graph_path = os.path.join(self.save_path, 'gin_{}_{}.bin'.format(self.name, self.hash))
info_path = os.path.join(self.save_path, 'gin_{}_{}.pkl'.format(self.name, self.hash))
graphs, label_dict = load_graphs(str(graph_path))
info_dict = load_info(str(info_path))
self.graphs = graphs
self.labels = label_dict['labels']
self.N = info_dict['N']
self.n = info_dict['n']
self.m = info_dict['m']
self.self_loop = info_dict['self_loop']
self.gclasses = info_dict['gclasses']
self.nclasses = info_dict['nclasses']
self.eclasses = info_dict['eclasses']
self.dim_nfeats = info_dict['dim_nfeats']
self.glabel_dict = info_dict['glabel_dict']
self.nlabel_dict = info_dict['nlabel_dict']
self.elabel_dict = info_dict['elabel_dict']
self.ndegree_dict = info_dict['ndegree_dict']
self.degree_as_nlabel = info_dict['degree_as_nlabel']
def has_cache(self):
graph_path = os.path.join(self.save_path, 'gin_{}_{}.bin'.format(self.name, self.hash))
info_path = os.path.join(self.save_path, 'gin_{}_{}.pkl'.format(self.name, self.hash))
if os.path.exists(graph_path) and os.path.exists(info_path):
return True
return False
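# How these hooks fit together (a sketch of the DGLDataset workflow, not new
# code in this commit): on first construction process() builds the graphs and
# save() writes gin_<name>_<hash>.bin / .pkl under save_path; later
# constructions with the same hash key see has_cache() return True and call
# load() instead, unless force_reload=True.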