"docs/source/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "a36d0841b59af1003621e3b72560ca06e33fc23f"
Unverified Commit 18c960a1 authored by Tong He's avatar Tong He Committed by GitHub
Browse files

[Dataset] Builtin MiniGCDataset (#1890)

* add minigc

* fix

* update docstrings

* update docstring

* add hash to fix save/load

* update hash
parent f05bd497
"""A mini synthetic dataset for graph classification benchmark.""" """A mini synthetic dataset for graph classification benchmark."""
import math import math, os
import networkx as nx import networkx as nx
import numpy as np import numpy as np
from .. import convert from .dgl_dataset import DGLDataset
from .utils import save_graphs, load_graphs, makedirs
from .. import backend as F
from ..convert import graph as dgl_graph
from ..transform import add_self_loop
__all__ = ['MiniGCDataset'] __all__ = ['MiniGCDataset']
class MiniGCDataset(object): class MiniGCDataset(DGLDataset):
"""The dataset class. """The dataset class.
The dataset contains 8 different types of graphs. The dataset contains 8 different types of graphs.
* class 0 : cycle graph * class 0 : cycle graph
* class 1 : star graph * class 1 : star graph
* class 2 : wheel graph * class 2 : wheel graph
...@@ -21,9 +24,6 @@ class MiniGCDataset(object): ...@@ -21,9 +24,6 @@ class MiniGCDataset(object):
* class 6 : clique graph * class 6 : clique graph
* class 7 : circular ladder graph * class 7 : circular ladder graph
.. note::
This dataset class is compatible with pytorch's :class:`Dataset` class.
Parameters Parameters
---------- ----------
num_graphs: int num_graphs: int
...@@ -32,41 +32,107 @@ class MiniGCDataset(object): ...@@ -32,41 +32,107 @@ class MiniGCDataset(object):
Minimum number of nodes for graphs Minimum number of nodes for graphs
max_num_v: int max_num_v: int
Maximum number of nodes for graphs Maximum number of nodes for graphs
seed : int, default is 0
Random seed for data generation
Attributes
----------
num_graphs : int
Number of graphs
min_num_v : int
The minimum number of nodes
max_num_v : int
The maximum number of nodes
num_classes : int
The number of classes
Examples
--------
>>> data = MiniGCDataset(100, 16, 32, seed=0)
**The dataset instance is an iterable**
>>> len(data)
100
>>> g, label = data[64]
>>> g
Graph(num_nodes=20, num_edges=82,
ndata_schemes={}
edata_schemes={})
>>> label
tensor(5)
**Batch the graphs and labels for mini-batch training**
>>> graphs, labels = zip(*[data[i] for i in range(16)])
>>> batched_graphs = dgl.batch(graphs)
>>> batched_labels = torch.tensor(labels)
>>> batched_graphs
Graph(num_nodes=356, num_edges=1060,
ndata_schemes={}
edata_schemes={})
""" """
def __init__(self, num_graphs, min_num_v, max_num_v):
super(MiniGCDataset, self).__init__() def __init__(self, num_graphs, min_num_v, max_num_v, seed=0,
save_graph=True, force_reload=False, verbose=False):
self.num_graphs = num_graphs self.num_graphs = num_graphs
self.min_num_v = min_num_v self.min_num_v = min_num_v
self.max_num_v = max_num_v self.max_num_v = max_num_v
self.seed = seed
self.save_graph = save_graph
super(MiniGCDataset, self).__init__(name="minigc", hash_key=(num_graphs, min_num_v, max_num_v, seed),
force_reload=force_reload,
verbose=verbose)
def process(self):
self.graphs = [] self.graphs = []
self.labels = [] self.labels = []
self._generate() self._generate(self.seed)
def __len__(self): def __len__(self):
"""Return the number of graphs in the dataset.""" """Return the number of graphs in the dataset."""
return len(self.graphs) return len(self.graphs)
def __getitem__(self, idx): def __getitem__(self, idx):
"""Get the i^th sample. """Get the idx-th sample.
Parameters Parameters
---------- ----------
idx : int idx : int
The sample index. The sample index.
Returns Returns
------- -------
(dgl.DGLGraph, int) (dgl.Graph, int)
The graph and its label. The graph and its label.
""" """
return self.graphs[idx], self.labels[idx] return self.graphs[idx], self.labels[idx]
def has_cache(self):
    """Report whether the serialized graph file for this dataset exists.

    The file looked up is ``dgl_graph_<hash>.bin`` inside ``self.save_path``,
    where ``<hash>`` is ``self.hash``; this is the same path that ``save``
    writes and ``load`` reads.

    Returns
    -------
    bool
        True if the file exists on disk, False otherwise.
    """
    graph_path = os.path.join(self.save_path, 'dgl_graph_{}.bin'.format(self.hash))
    # os.path.exists already yields a bool; no need to branch on it.
    return os.path.exists(graph_path)
def save(self):
    """Write the graph list and the labels to disk.

    Nothing is written when ``self.save_graph`` is falsy. Otherwise the
    graphs are passed to ``save_graphs`` together with a dict holding the
    labels under the ``'labels'`` key; the target file is
    ``dgl_graph_<hash>.bin`` inside ``self.save_path``.
    """
    if not self.save_graph:
        return
    file_name = 'dgl_graph_{}.bin'.format(self.hash)
    target = os.path.join(self.save_path, file_name)
    label_dict = {'labels': self.labels}
    save_graphs(str(target), self.graphs, label_dict)
def load(self):
    """Read the graph list and the labels back from disk.

    The file read is ``dgl_graph_<hash>.bin`` inside ``self.save_path``.
    ``load_graphs`` returns the graphs and a label dict; the graphs are
    stored on ``self.graphs`` and the ``'labels'`` entry on ``self.labels``.
    """
    file_name = 'dgl_graph_{}.bin'.format(self.hash)
    source = os.path.join(self.save_path, file_name)
    loaded_graphs, loaded_labels = load_graphs(source)
    self.graphs = loaded_graphs
    self.labels = loaded_labels['labels']
@property @property
def num_classes(self): def num_classes(self):
"""Number of classes.""" """Number of classes."""
return 8 return 8
def _generate(self): def _generate(self, seed):
if seed is not None:
np.random.seed(seed)
self._gen_cycle(self.num_graphs // 8) self._gen_cycle(self.num_graphs // 8)
self._gen_star(self.num_graphs // 8) self._gen_star(self.num_graphs // 8)
self._gen_wheel(self.num_graphs // 8) self._gen_wheel(self.num_graphs // 8)
...@@ -77,10 +143,9 @@ class MiniGCDataset(object): ...@@ -77,10 +143,9 @@ class MiniGCDataset(object):
self._gen_circular_ladder(self.num_graphs - len(self.graphs)) self._gen_circular_ladder(self.num_graphs - len(self.graphs))
# preprocess # preprocess
for i in range(self.num_graphs): for i in range(self.num_graphs):
self.graphs[i] = convert.graph(self.graphs[i]) # convert to Graph, and add self loops
# add self edges self.graphs[i] = add_self_loop(dgl_graph(self.graphs[i]))
nodes = self.graphs[i].nodes() self.labels = F.tensor(np.array(self.labels).astype(np.int))
self.graphs[i].add_edges(nodes, nodes)
def _gen_cycle(self, n): def _gen_cycle(self, n):
for _ in range(n): for _ in range(n):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment