"src/git@developer.sourcefind.cn:one/TransferBench.git" did not exist on "b0e6ccaf57aced24b2ccf444aa09cbcee81ec5e0"
Unverified Commit f05bd497 authored by Tong He's avatar Tong He Committed by GitHub
Browse files

[Dataset] Builtin GINDataset (#1889)



* update GINDataset

* update

* update docstring

* update

* update docstrings

* add hash to fix save/load

* fix some stupid bugs
Co-authored-by: default avatarxiang song(charlie.song) <classicxsong@gmail.com>
parent fb02aa2d
...@@ -55,7 +55,7 @@ class DGLDataset(object): ...@@ -55,7 +55,7 @@ class DGLDataset(object):
self._url = url self._url = url
self._force_reload = force_reload self._force_reload = force_reload
self._verbose = verbose self._verbose = verbose
self._hash_key = hask_key self._hash_key = hash_key
self._hash = self._get_hash() self._hash = self._get_hash()
# if no dir is provided, the default dgl download dir is used. # if no dir is provided, the default dgl download dir is used.
......
"""Dataset for Graph Isomorphism Network(GIN) """Dataset for Graph Isomorphism Network(GIN)
(chen jun): Used for compacted graph kernel dataset in GIN (chen jun): Used for compacted graph kernel dataset in GIN
Data sets include: Data sets include:
MUTAG, COLLAB, IMDBBINARY, IMDBMULTI, NCI1, PROTEINS, PTC, REDDITBINARY, REDDITMULTI5K MUTAG, COLLAB, IMDBBINARY, IMDBMULTI, NCI1, PROTEINS, PTC, REDDITBINARY, REDDITMULTI5K
https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip
""" """
...@@ -12,23 +10,19 @@ import numpy as np ...@@ -12,23 +10,19 @@ import numpy as np
from .. import backend as F from .. import backend as F
from .utils import download, extract_archive, get_download_dir, _get_dgl_url from .dgl_dataset import DGLBuiltinDataset
from .utils import loadtxt, save_graphs, load_graphs, save_info, load_info, download, extract_archive
from ..utils import retry_method_with_fix from ..utils import retry_method_with_fix
from ..convert import graph from ..convert import graph as dgl_graph
from .. import backend as F
_url = 'https://raw.githubusercontent.com/weihua916/powerful-gnns/master/dataset.zip'
class GINDataset(object): class GINDataset(DGLBuiltinDataset):
"""Datasets for Graph Isomorphism Network (GIN) """Datasets for Graph Isomorphism Network (GIN)
Adapted from https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip. Adapted from https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip.
The dataset contains the compact format of popular graph kernel datasets, which includes: The dataset contains the compact format of popular graph kernel datasets, which includes:
MUTAG, COLLAB, IMDBBINARY, IMDBMULTI, NCI1, PROTEINS, PTC, REDDITBINARY, REDDITMULTI5K MUTAG, COLLAB, IMDBBINARY, IMDBMULTI, NCI1, PROTEINS, PTC, REDDITBINARY, REDDITMULTI5K
This datset class processes all data sets listed above. For more graph kernel datasets, This datset class processes all data sets listed above. For more graph kernel datasets,
see :class:`TUDataset` see :class:`TUDataset`.
Paramters Paramters
--------- ---------
...@@ -38,23 +32,46 @@ class GINDataset(object): ...@@ -38,23 +32,46 @@ class GINDataset(object):
'IMDBBINARY', 'IMDBMULTI', \ 'IMDBBINARY', 'IMDBMULTI', \
'NCI1', 'PROTEINS', 'PTC', \ 'NCI1', 'PROTEINS', 'PTC', \
'REDDITBINARY', 'REDDITMULTI5K') 'REDDITBINARY', 'REDDITMULTI5K')
self_loop: boolean self_loop: bool
add self to self edge if true add self to self edge if true
degree_as_nlabel: boolean degree_as_nlabel: bool
take node degree as label and feature if true take node degree as label and feature if true
Examples
--------
>>> data = GINDataset(name='MUTAG', self_loop=False)
**The dataset instance is an iterable**
>>> len(data)
188
>>> g, label = data[128]
>>> g
Graph(num_nodes=13, num_edges=26,
ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(7,), dtype=torch.float64)}
edata_schemes={})
>>> label
tensor(1)
**Batch the graphs and labels for mini-batch training**
>>> graphs, labels = zip(*[data[i] for i in range(16)])
>>> batched_graphs = dgl.batch(graphs)
>>> batched_labels = torch.tensor(labels)
>>> batched_graphs
Graph(num_nodes=330, num_edges=748,
ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(7,), dtype=torch.float64)}
edata_schemes={})
""" """
def __init__(self, name, self_loop, degree_as_nlabel=False): def __init__(self, name, self_loop, degree_as_nlabel=False,
"""Initialize the dataset.""" raw_dir=None, force_reload=False, verbose=False):
self.name = name # MUTAG self._name = name # MUTAG
gin_url = 'https://raw.githubusercontent.com/weihua916/powerful-gnns/master/dataset.zip'
self.ds_name = 'nig' self.ds_name = 'nig'
self.extract_dir = self._get_extract_dir()
self.file = self._file_path()
self.self_loop = self_loop self.self_loop = self_loop
self.graphs = [] self.graphs = []
self.labels = [] self.labels = []
...@@ -79,17 +96,27 @@ class GINDataset(object): ...@@ -79,17 +96,27 @@ class GINDataset(object):
self.degree_as_nlabel = degree_as_nlabel self.degree_as_nlabel = degree_as_nlabel
self.nattrs_flag = False self.nattrs_flag = False
self.nlabels_flag = False self.nlabels_flag = False
self.verbosity = False
# calc all values super(GINDataset, self).__init__(name=name, url=gin_url, hash_key=(name, self_loop, degree_as_nlabel),
self._load() raw_dir=raw_dir, force_reload=force_reload, verbose=verbose)
@property
def raw_path(self):
    """Directory that holds the extracted GIN archive.

    Returns
    -------
    str
        The ``GINDataset`` sub-directory of ``self.raw_dir``.
    """
    return os.path.join(self.raw_dir, 'GINDataset')
def download(self):
    r"""Automatically download the data and extract it.

    The archive at ``self.url`` is fetched to ``GINDataset.zip`` inside
    ``self.raw_dir`` and then unpacked into ``self.raw_path``.
    """
    zip_file_path = os.path.join(self.raw_dir, 'GINDataset.zip')
    # ``download`` here is the module-level helper imported from ``.utils``,
    # not this method (method names are not visible as bare globals).
    download(self.url, path=zip_file_path)
    extract_archive(zip_file_path, self.raw_path)
def __len__(self): def __len__(self):
"""Return the number of graphs in the dataset.""" """Return the number of graphs in the dataset."""
return len(self.graphs) return len(self.graphs)
def __getitem__(self, idx): def __getitem__(self, idx):
"""Get the i^th sample. """Get the idx-th sample.
Paramters Paramters
--------- ---------
...@@ -98,39 +125,26 @@ class GINDataset(object): ...@@ -98,39 +125,26 @@ class GINDataset(object):
Returns Returns
------- -------
(dgl.DGLGraph, int) (dgl.Graph, int)
The graph and its label. The graph and its label.
""" """
return self.graphs[idx], self.labels[idx] return self.graphs[idx], self.labels[idx]
def _get_extract_dir(self):
return os.path.join(get_download_dir(), "{}".format(self.ds_name))
def _download(self):
download_dir = get_download_dir()
zip_file_path = os.path.join(
download_dir, "{}.zip".format(self.ds_name))
# TODO move to dgl host _get_dgl_url
download(_url, path=zip_file_path)
extract_dir = self._get_extract_dir()
extract_archive(zip_file_path, extract_dir)
def _file_path(self): def _file_path(self):
return os.path.join(self.extract_dir, "dataset", self.name, "{}.txt".format(self.name)) return os.path.join(self.raw_dir, "GINDataset", 'dataset', self.name, "{}.txt".format(self.name))
@retry_method_with_fix(_download) def process(self):
def _load(self):
""" Loads input dataset from dataset/NAME/NAME.txt file """ Loads input dataset from dataset/NAME/NAME.txt file
""" """
if self.verbose:
print('loading data...') print('loading data...')
self.file = self._file_path()
with open(self.file, 'r') as f: with open(self.file, 'r') as f:
# line_1 == N, total number of graphs # line_1 == N, total number of graphs
self.N = int(f.readline().strip()) self.N = int(f.readline().strip())
for i in range(self.N): for i in range(self.N):
if (i + 1) % 10 == 0 and self.verbosity is True: if (i + 1) % 10 == 0 and self.verbose is True:
print('processing graph {}...'.format(i + 1)) print('processing graph {}...'.format(i + 1))
grow = f.readline().strip().split() grow = f.readline().strip().split()
...@@ -145,7 +159,7 @@ class GINDataset(object): ...@@ -145,7 +159,7 @@ class GINDataset(object):
self.labels.append(self.glabel_dict[glabel]) self.labels.append(self.glabel_dict[glabel])
g = graph([]) g = dgl_graph([])
g.add_nodes(n_nodes) g.add_nodes(n_nodes)
nlabels = [] # node labels nlabels = [] # node labels
...@@ -179,17 +193,17 @@ class GINDataset(object): ...@@ -179,17 +193,17 @@ class GINDataset(object):
m_edges += nrow[1] m_edges += nrow[1]
g.add_edges(j, nrow[2:]) g.add_edges(j, nrow[2:])
if (j + 1) % 10 == 0 and self.verbosity is True: # add self loop
if self.self_loop:
m_edges += 1
g.add_edges(j, j)
if (j + 1) % 10 == 0 and self.verbose is True:
print( print(
'processing node {} of graph {}...'.format( 'processing node {} of graph {}...'.format(
j + 1, i + 1)) j + 1, i + 1))
print('this node has {} edgs.'.format( print('this node has {} edgs.'.format(
nrow[1])) nrow[1]))
# Add self loops
if self.self_loop:
m_edges += n_nodes
g.add_edges(F.arange(0, n_nodes), F.arange(0, n_nodes))
if nattrs != []: if nattrs != []:
nattrs = np.stack(nattrs) nattrs = np.stack(nattrs)
...@@ -198,7 +212,7 @@ class GINDataset(object): ...@@ -198,7 +212,7 @@ class GINDataset(object):
else: else:
nattrs = None nattrs = None
g.ndata['label'] = F.tensor(np.asarray(nlabels)) g.ndata['label'] = F.tensor(nlabels)
if len(self.nlabel_dict) > 1: if len(self.nlabel_dict) > 1:
self.nlabels_flag = True self.nlabels_flag = True
...@@ -210,13 +224,16 @@ class GINDataset(object): ...@@ -210,13 +224,16 @@ class GINDataset(object):
self.graphs.append(g) self.graphs.append(g)
self.labels = F.tensor(self.labels)
# if no attr # if no attr
if not self.nattrs_flag: if not self.nattrs_flag:
print('there are no node features in this dataset!') if self.verbose:
print('there are no node features in this dataset!')
label2idx = {} label2idx = {}
# generate node attr by node degree # generate node attr by node degree
if self.degree_as_nlabel: if self.degree_as_nlabel:
print('generate node features by node degree...') if self.verbose:
print('generate node features by node degree...')
nlabel_set = set([]) nlabel_set = set([])
for g in self.graphs: for g in self.graphs:
# actually this label shouldn't be updated # actually this label shouldn't be updated
...@@ -235,7 +252,8 @@ class GINDataset(object): ...@@ -235,7 +252,8 @@ class GINDataset(object):
label2idx = self.ndegree_dict label2idx = self.ndegree_dict
# generate node attr by node label # generate node attr by node label
else: else:
print('generate node features by node label...') if self.verbose:
print('generate node features by node label...')
label2idx = self.nlabel_dict label2idx = self.nlabel_dict
for g in self.graphs: for g in self.graphs:
...@@ -249,23 +267,74 @@ class GINDataset(object): ...@@ -249,23 +267,74 @@ class GINDataset(object):
self.eclasses = len(self.elabel_dict) self.eclasses = len(self.elabel_dict)
self.dim_nfeats = len(self.graphs[0].ndata['attr'][0]) self.dim_nfeats = len(self.graphs[0].ndata['attr'][0])
print('Done.') if self.verbose:
print( print('Done.')
""" print(
-------- Data Statistics --------' """
#Graphs: %d -------- Data Statistics --------'
#Graph Classes: %d #Graphs: %d
#Nodes: %d #Graph Classes: %d
#Node Classes: %d #Nodes: %d
#Node Features Dim: %d #Node Classes: %d
#Edges: %d #Node Features Dim: %d
#Edge Classes: %d #Edges: %d
Avg. of #Nodes: %.2f #Edge Classes: %d
Avg. of #Edges: %.2f Avg. of #Nodes: %.2f
Graph Relabeled: %s Avg. of #Edges: %.2f
Node Relabeled: %s Graph Relabeled: %s
Degree Relabeled(If degree_as_nlabel=True): %s \n """ % ( Node Relabeled: %s
self.N, self.gclasses, self.n, self.nclasses, Degree Relabeled(If degree_as_nlabel=True): %s \n """ % (
self.dim_nfeats, self.m, self.eclasses, self.N, self.gclasses, self.n, self.nclasses,
self.n / self.N, self.m / self.N, self.glabel_dict, self.dim_nfeats, self.m, self.eclasses,
self.nlabel_dict, self.ndegree_dict)) self.n / self.N, self.m / self.N, self.glabel_dict,
self.nlabel_dict, self.ndegree_dict))
def save(self):
    """Write the processed dataset to ``self.save_path``.

    Two files are produced, both named ``gin_<name>_<hash>``:

    * ``.bin`` -- the graphs together with the graph labels, written with
      ``save_graphs``.
    * ``.pkl`` -- the dataset statistics and label dictionaries, written
      with ``save_info``.
    """
    # Both files share the same stem; only the extension differs.
    stem = 'gin_{}_{}'.format(self.name, self.hash)
    graph_path = os.path.join(self.save_path, stem + '.bin')
    info_path = os.path.join(self.save_path, stem + '.pkl')

    label_dict = {'labels': self.labels}
    info_dict = {
        'N': self.N,
        'n': self.n,
        'm': self.m,
        'self_loop': self.self_loop,
        'gclasses': self.gclasses,
        'nclasses': self.nclasses,
        'eclasses': self.eclasses,
        'dim_nfeats': self.dim_nfeats,
        'degree_as_nlabel': self.degree_as_nlabel,
        'glabel_dict': self.glabel_dict,
        'nlabel_dict': self.nlabel_dict,
        'elabel_dict': self.elabel_dict,
        'ndegree_dict': self.ndegree_dict,
    }

    save_graphs(str(graph_path), self.graphs, label_dict)
    save_info(str(info_path), info_dict)
def load(self):
    """Restore the dataset from the files written by ``save``.

    Reads ``gin_<name>_<hash>.bin`` with ``load_graphs`` and
    ``gin_<name>_<hash>.pkl`` with ``load_info`` from ``self.save_path``,
    then copies the graphs, labels and every stored statistic back onto
    the instance.

    Raises
    ------
    KeyError
        If the label or info dictionary lacks one of the expected keys.
    """
    stem = 'gin_{}_{}'.format(self.name, self.hash)
    graph_path = os.path.join(self.save_path, stem + '.bin')
    info_path = os.path.join(self.save_path, stem + '.pkl')

    graphs, label_dict = load_graphs(str(graph_path))
    info_dict = load_info(str(info_path))

    self.graphs = graphs
    self.labels = label_dict['labels']

    # Each key is restored onto the attribute of the same name.
    for key in ('N', 'n', 'm', 'self_loop',
                'gclasses', 'nclasses', 'eclasses', 'dim_nfeats',
                'glabel_dict', 'nlabel_dict', 'elabel_dict',
                'ndegree_dict', 'degree_as_nlabel'):
        setattr(self, key, info_dict[key])
def has_cache(self):
    """Report whether a processed copy of the dataset is already on disk.

    Returns
    -------
    bool
        True only if both ``gin_<name>_<hash>.bin`` and
        ``gin_<name>_<hash>.pkl`` exist under ``self.save_path``.
    """
    stem = 'gin_{}_{}'.format(self.name, self.hash)
    graph_path = os.path.join(self.save_path, stem + '.bin')
    info_path = os.path.join(self.save_path, stem + '.pkl')
    return os.path.exists(graph_path) and os.path.exists(info_path)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment