Unverified Commit abf12fc7 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Data] Synthetic dataset for graph classificaiton (#364)

* minigc dataset

* more comments

* sphinx
parent 525bb612
......@@ -26,3 +26,9 @@ For more information about the dataset, see `Sentiment Analysis <https://nlp.sta
.. autoclass:: SST
:members: __getitem__, __len__
Mini graph classification dataset
`````````````````````````````````
.. autoclass:: MiniGC
:members: __getitem__, __len__, num_classes
......@@ -3,6 +3,7 @@ from __future__ import absolute_import
from . import citation_graph as citegrh
from .citation_graph import CoraBinary
from .minigc import *
from .tree import *
from .utils import *
from .sbm import SBMMixture
......
"""A mini synthetic dataset for graph classification benchmark."""
from collections.abc import Sequence
import math
import networkx as nx
import numpy as np
from ..graph import DGLGraph
__all__ = ['MiniGCDataset']
class MiniGCDataset(object):
"""The dataset class.
The datset contains 8 different types of graphs.
- class 0 : cycle graph
- class 1 : star graph
- class 2 : wheel graph
- class 3 : lollipop graph
- class 4 : hypercube graph
- class 5 : grid graph
- class 6 : clique graph
- class 7 : circular ladder graph
"""
def __init__(self, num_graphs, min_num_v, max_num_v):
"""
Parameters
----------
num_graphs: int
Number of graphs in this dataset.
min_num_v: int
Minimum number of nodes for graphs
max_num_v: int
Maximum number of nodes for graphs
"""
super(MiniGCDataset, self).__init__()
self.num_graphs = num_graphs
self.min_num_v = min_num_v
self.max_num_v = max_num_v
self.graphs = []
self.labels = []
self._generate()
def __len__(self):
return len(self.graphs)
def __getitem__(self, idx):
return self.graphs[idx], self.labels[idx]
@property
def num_classes(self):
"""Number of classes."""
return 8
def _generate(self):
self._gen_cycle(self.num_graphs // 8)
self._gen_star(self.num_graphs // 8)
self._gen_wheel(self.num_graphs // 8)
self._gen_lollipop(self.num_graphs // 8)
self._gen_hypercube(self.num_graphs // 8)
self._gen_grid(self.num_graphs // 8)
self._gen_clique(self.num_graphs // 8)
self._gen_circular_ladder(self.num_graphs - len(self.graphs))
# preprocess
for i in range(self.num_graphs):
self.graphs[i] = DGLGraph(self.graphs[i])
# add self edges
nodes = self.graphs[i].nodes()
self.graphs[i].add_edges(nodes, nodes)
def _gen_cycle(self, n):
for _ in range(n):
num_v = np.random.randint(self.min_num_v, self.max_num_v)
g = nx.cycle_graph(num_v)
self.graphs.append(g)
self.labels.append(0)
def _gen_star(self, n):
for _ in range(n):
num_v = np.random.randint(self.min_num_v, self.max_num_v)
# nx.star_graph(N) gives a star graph with N+1 nodes
g = nx.star_graph(num_v - 1)
self.graphs.append(g)
self.labels.append(1)
def _gen_wheel(self, n):
for _ in range(n):
num_v = np.random.randint(self.min_num_v, self.max_num_v)
g = nx.wheel_graph(num_v)
self.graphs.append(g)
self.labels.append(2)
def _gen_lollipop(self, n):
for _ in range(n):
num_v = np.random.randint(self.min_num_v, self.max_num_v)
path_len = np.random.randint(2, num_v // 2)
g = nx.lollipop_graph(m=num_v - path_len, n=path_len)
self.graphs.append(g)
self.labels.append(3)
def _gen_hypercube(self, n):
for _ in range(n):
num_v = np.random.randint(self.min_num_v, self.max_num_v)
g = nx.hypercube_graph(int(math.log(num_v, 2)))
g = nx.convert_node_labels_to_integers(g)
self.graphs.append(g)
self.labels.append(4)
def _gen_grid(self, n):
for _ in range(n):
num_v = np.random.randint(self.min_num_v, self.max_num_v)
assert num_v >= 4, 'We require a grid graph to contain at least two ' \
'rows and two columns, thus 4 nodes, got {:d} ' \
'nodes'.format(num_v)
n_rows = np.random.randint(2, num_v // 2)
n_cols = num_v // n_rows
g = nx.grid_graph([n_rows, n_cols])
g = nx.convert_node_labels_to_integers(g)
self.graphs.append(g)
self.labels.append(5)
def _gen_clique(self, n):
for _ in range(n):
num_v = np.random.randint(self.min_num_v, self.max_num_v)
g = nx.complete_graph(num_v)
self.graphs.append(g)
self.labels.append(6)
def _gen_circular_ladder(self, n):
for _ in range(n):
num_v = np.random.randint(self.min_num_v, self.max_num_v)
g = nx.circular_ladder_graph(num_v)
self.graphs.append(g)
self.labels.append(7)
import dgl.data as data
def test_minigc():
ds = data.MiniGCDataset(16, 10, 20)
g, l = list(zip(*ds))
print(g, l)
if __name__ == '__main__':
test_minigc()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment