Unverified Commit 18c960a1 authored by Tong He's avatar Tong He Committed by GitHub
Browse files

[Dataset] Builtin MiniGCDataset (#1890)

* add minigc

* fix

* update docstrings

* update docstring

* add hash to fix save/load

* update hash
parent f05bd497
"""A mini synthetic dataset for graph classification benchmark."""
import math
import math, os
import networkx as nx
import numpy as np
from .. import convert
from .dgl_dataset import DGLDataset
from .utils import save_graphs, load_graphs, makedirs
from .. import backend as F
from ..convert import graph as dgl_graph
from ..transform import add_self_loop
__all__ = ['MiniGCDataset']
class MiniGCDataset(object):
class MiniGCDataset(DGLDataset):
"""The dataset class.
The dataset contains 8 different types of graphs.
* class 0 : cycle graph
* class 1 : star graph
* class 2 : wheel graph
......@@ -21,9 +24,6 @@ class MiniGCDataset(object):
* class 6 : clique graph
* class 7 : circular ladder graph
.. note::
This dataset class is compatible with pytorch's :class:`Dataset` class.
Parameters
----------
num_graphs: int
......@@ -32,41 +32,107 @@ class MiniGCDataset(object):
Minimum number of nodes for graphs
max_num_v: int
Maximum number of nodes for graphs
seed : int, default is 0
Random seed for data generation
Attributes
----------
num_graphs : int
Number of graphs
min_num_v : int
The minimum number of nodes
max_num_v : int
The maximum number of nodes
num_classes : int
The number of classes
Examples
--------
>>> data = MiniGCDataset(100, 16, 32, seed=0)
**The dataset instance is an iterable**
>>> len(data)
100
>>> g, label = data[64]
>>> g
Graph(num_nodes=20, num_edges=82,
ndata_schemes={}
edata_schemes={})
>>> label
tensor(5)
**Batch the graphs and labels for mini-batch training**
>>> graphs, labels = zip(*[data[i] for i in range(16)])
>>> batched_graphs = dgl.batch(graphs)
>>> batched_labels = torch.tensor(labels)
>>> batched_graphs
Graph(num_nodes=356, num_edges=1060,
ndata_schemes={}
edata_schemes={})
"""
def __init__(self, num_graphs, min_num_v, max_num_v):
super(MiniGCDataset, self).__init__()
def __init__(self, num_graphs, min_num_v, max_num_v, seed=0,
save_graph=True, force_reload=False, verbose=False):
self.num_graphs = num_graphs
self.min_num_v = min_num_v
self.max_num_v = max_num_v
self.seed = seed
self.save_graph = save_graph
super(MiniGCDataset, self).__init__(name="minigc", hash_key=(num_graphs, min_num_v, max_num_v, seed),
force_reload=force_reload,
verbose=verbose)
def process(self):
self.graphs = []
self.labels = []
self._generate()
self._generate(self.seed)
def __len__(self):
    """Return the number of graphs in the dataset.

    Returns
    -------
    int
        The length of ``self.graphs``.
    """
    return len(self.graphs)
def __getitem__(self, idx):
"""Get the i^th sample.
"""Get the idx-th sample.
Parameters
---------
idx : int
The sample index.
Returns
-------
(dgl.DGLGraph, int)
(dgl.Graph, int)
The graph and its label.
"""
return self.graphs[idx], self.labels[idx]
def has_cache(self):
    """Check whether a saved copy of this dataset exists on disk.

    Returns
    -------
    bool
        True if the file ``dgl_graph_<hash>.bin`` exists under
        ``self.save_path``, False otherwise.
    """
    graph_path = os.path.join(self.save_path, 'dgl_graph_{}.bin'.format(self.hash))
    # os.path.exists already yields a bool, so return it directly
    # instead of branching to True/False.
    return os.path.exists(graph_path)
def save(self):
    """Write the graph list and the labels to the cache file.

    Nothing is written when ``self.save_graph`` is falsy. Otherwise the
    graphs and a ``{'labels': ...}`` dict are passed to ``save_graphs``
    with the path ``dgl_graph_<hash>.bin`` under ``self.save_path``.
    """
    if not self.save_graph:
        return
    target = os.path.join(self.save_path, 'dgl_graph_{}.bin'.format(self.hash))
    save_graphs(str(target), self.graphs, {'labels': self.labels})
def load(self):
    """Restore ``self.graphs`` and ``self.labels`` from the cache file.

    Reads ``dgl_graph_<hash>.bin`` under ``self.save_path`` via
    ``load_graphs`` and takes the labels from the ``'labels'`` entry of
    the returned dict.
    """
    cache_file = os.path.join(self.save_path, 'dgl_graph_{}.bin'.format(self.hash))
    loaded_graphs, extras = load_graphs(cache_file)
    self.graphs = loaded_graphs
    self.labels = extras['labels']
@property
def num_classes(self):
    """Number of classes.

    Returns
    -------
    int
        Always 8, matching the 8 graph types described in the class
        docstring.
    """
    return 8
def _generate(self):
def _generate(self, seed):
if seed is not None:
np.random.seed(seed)
self._gen_cycle(self.num_graphs // 8)
self._gen_star(self.num_graphs // 8)
self._gen_wheel(self.num_graphs // 8)
......@@ -77,10 +143,9 @@ class MiniGCDataset(object):
self._gen_circular_ladder(self.num_graphs - len(self.graphs))
# preprocess
for i in range(self.num_graphs):
self.graphs[i] = convert.graph(self.graphs[i])
# add self edges
nodes = self.graphs[i].nodes()
self.graphs[i].add_edges(nodes, nodes)
# convert to Graph, and add self loops
self.graphs[i] = add_self_loop(dgl_graph(self.graphs[i]))
self.labels = F.tensor(np.array(self.labels).astype(np.int))
def _gen_cycle(self, n):
for _ in range(n):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment