Unverified Commit 451ed6d8 authored by Xiangkun Hu's avatar Xiangkun Hu Committed by GitHub
Browse files

[Dataset] SBMMixtureDataset (#1920)



* PPIDataset

* Revert "PPIDataset"

This reverts commit 264bd0c960cfa698a7bb946dad132bf52c2d0c8a.

* SBMMixture dataset

* Update sbm.py

* fix example

* Update sbm.py

* Revert "Update sbm.py"

This reverts commit 066db5c89bd5e159981ae7cad1bfb883ea5db71d.

* Update sbm.py
Co-authored-by: default avatarxiang song(charlie.song) <classicxsong@gmail.com>
parent 4a8cc489
...@@ -17,7 +17,7 @@ import torch.nn.functional as F ...@@ -17,7 +17,7 @@ import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from dgl.data import SBMMixture from dgl.data import SBMMixtureDataset
import gnn import gnn
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -38,7 +38,7 @@ args = parser.parse_args() ...@@ -38,7 +38,7 @@ args = parser.parse_args()
dev = th.device('cpu') if args.gpu < 0 else th.device('cuda:%d' % args.gpu) dev = th.device('cpu') if args.gpu < 0 else th.device('cuda:%d' % args.gpu)
K = args.n_communities K = args.n_communities
training_dataset = SBMMixture(args.n_graphs, args.n_nodes, K) training_dataset = SBMMixtureDataset(args.n_graphs, args.n_nodes, K)
training_loader = DataLoader(training_dataset, args.batch_size, training_loader = DataLoader(training_dataset, args.batch_size,
collate_fn=training_dataset.collate_fn, drop_last=True) collate_fn=training_dataset.collate_fn, drop_last=True)
...@@ -105,7 +105,7 @@ def test(): ...@@ -105,7 +105,7 @@ def test():
N = 1 N = 1
overlap_list = [] overlap_list = []
for p, q in zip(p_list, q_list): for p, q in zip(p_list, q_list):
dataset = SBMMixture(N, args.n_nodes, K, pq=[[p, q]] * N) dataset = SBMMixtureDataset(N, args.n_nodes, K, pq=[[p, q]] * N)
loader = DataLoader(dataset, N, collate_fn=dataset.collate_fn) loader = DataLoader(dataset, N, collate_fn=dataset.collate_fn)
g, lg, deg_g, deg_lg, pm_pd = next(iter(loader)) g, lg, deg_g, deg_lg, pm_pd = next(iter(loader))
z = inference(g, lg, deg_g, deg_lg, pm_pd) z = inference(g, lg, deg_g, deg_lg, pm_pd)
......
...@@ -6,7 +6,7 @@ from .citation_graph import CoraBinary, CitationGraphDataset ...@@ -6,7 +6,7 @@ from .citation_graph import CoraBinary, CitationGraphDataset
from .minigc import * from .minigc import *
from .tree import SST, SSTDataset from .tree import SST, SSTDataset
from .utils import * from .utils import *
from .sbm import SBMMixture from .sbm import SBMMixture, SBMMixtureDataset
from .reddit import RedditDataset from .reddit import RedditDataset
from .ppi import PPIDataset, LegacyPPIDataset from .ppi import PPIDataset, LegacyPPIDataset
from .tu import TUDataset, LegacyTUDataset from .tu import TUDataset, LegacyTUDataset
......
"""Dataset for stochastic block model.""" """Dataset for stochastic block model."""
import math import math
import random import random
import os
import numpy as np import numpy as np
import numpy.random as npr import numpy.random as npr
import scipy as sp import scipy as sp
from .. import convert from .dgl_dataset import DGLDataset
from ..convert import graph as dgl_graph
from .. import batch from .. import batch
from .utils import save_info, save_graphs, load_info, load_graphs
def sbm(n_blocks, block_size, p, q, rng=None): def sbm(n_blocks, block_size, p, q, rng=None):
""" (Symmetric) Stochastic Block Model """ (Symmetric) Stochastic Block Model
...@@ -22,6 +26,8 @@ def sbm(n_blocks, block_size, p, q, rng=None): ...@@ -22,6 +26,8 @@ def sbm(n_blocks, block_size, p, q, rng=None):
Probability for intra-community edge. Probability for intra-community edge.
q : float q : float
Probability for inter-community edge. Probability for inter-community edge.
rng : numpy.random.RandomState, optional
Random number generator.
Returns Returns
------- -------
...@@ -49,9 +55,12 @@ def sbm(n_blocks, block_size, p, q, rng=None): ...@@ -49,9 +55,12 @@ def sbm(n_blocks, block_size, p, q, rng=None):
adj = sp.sparse.triu(a) + sp.sparse.triu(a, 1).transpose() adj = sp.sparse.triu(a) + sp.sparse.triu(a, 1).transpose()
return adj return adj
class SBMMixture:
""" Symmetric Stochastic Block Model Mixture class SBMMixtureDataset(DGLDataset):
Please refer to Appendix C of "Supervised Community Detection with Hierarchical Graph Neural Networks" (https://arxiv.org/abs/1705.08415) for details. r""" Symmetric Stochastic Block Model Mixture
Reference: Appendix C of "Supervised Community Detection with Hierarchical
Graph Neural Networks" (https://arxiv.org/abs/1705.08415).
Parameters Parameters
---------- ----------
...@@ -62,40 +71,123 @@ class SBMMixture: ...@@ -62,40 +71,123 @@ class SBMMixture:
n_communities : int n_communities : int
Number of communities. Number of communities.
k : int, optional k : int, optional
Multiplier. Multiplier. Default: 2
avg_deg : int, optional avg_deg : int, optional
Average degree. Average degree. Default: 3
pq : list of pair of nonnegative float or str, optional pq : list of pair of nonnegative float or str, optional
Random densities. Random densities. This parameter is for future extension,
for now it's always using the default value.
Default: Appendix_C
rng : numpy.random.RandomState, optional rng : numpy.random.RandomState, optional
Random number generator. Random number generator. If not given, it's numpy.random.RandomState() with `seed=None`,
which read data from /dev/urandom (or the Windows analogue) if available or seed from
the clock otherwise.
Default: None
Raises
------
RuntimeError is raised if pq is not a list or string.
Examples
--------
>>> data = SBMMixtureDataset(n_graphs=16, n_nodes=10000, n_communities=2)
>>> from torch.utils.data import DataLoader
>>> dataloader = DataLoader(data, batch_size=1, collate_fn=data.collate_fn)
>>> for graph, line_graph, graph_degrees, line_graph_degrees, pm_pd in dataloader:
... # your code here
""" """
def __init__(self,
             n_graphs,
             n_nodes,
             n_communities,
             k=2,
             avg_deg=3,
             pq='Appendix_C',
             rng=None):
    """Store the generation parameters and start DGLDataset's pipeline.

    Parameters
    ----------
    n_graphs : int
        Number of graphs to generate.
    n_nodes : int
        Number of nodes per graph; must be divisible by ``n_communities``.
    n_communities : int
        Number of communities.
    k : int, optional
        Multiplier. Default: 2
    avg_deg : int, optional
        Average degree. Default: 3
    pq : list of pair of nonnegative float or str, optional
        Random densities. Default: 'Appendix_C'
    rng : numpy.random.RandomState, optional
        Random number generator. Default: None
    """
    self._n_graphs = n_graphs
    self._n_nodes = n_nodes
    self._n_communities = n_communities
    # Fail early with a readable message instead of a bare AssertionError.
    assert n_nodes % n_communities == 0, \
        'n_nodes ({}) must be divisible by n_communities ({})'.format(
            n_nodes, n_communities)
    self._block_size = n_nodes // n_communities
    self._k = k
    self._avg_deg = avg_deg
    self._pq = pq
    self._rng = rng
    # hash_key makes the on-disk cache unique to this configuration,
    # so differently-parameterized datasets do not collide in save_path.
    super(SBMMixtureDataset, self).__init__(
        name='sbmmixture',
        hash_key=(n_graphs, n_nodes, n_communities, k, avg_deg, pq, rng))
def process(self):
    """Generate the graphs, their line graphs, degree vectors and the
    Pm/Pd edge indicator data for all ``n_graphs`` graphs.

    Raises
    ------
    RuntimeError
        If ``pq`` is neither a list nor a str.
    """
    pq = self._pq
    # isinstance (rather than `type(x) is`) also accepts list/str subclasses.
    if isinstance(pq, list):
        assert len(pq) == self._n_graphs, \
            'len(pq) ({}) must equal n_graphs ({})'.format(
                len(pq), self._n_graphs)
    elif isinstance(pq, str):
        generator = {'Appendix_C': self._appendix_c}[pq]
        pq = [generator() for _ in range(self._n_graphs)]
    else:
        # Original raised a bare RuntimeError(); include the offending type.
        raise RuntimeError('pq must be a list or a str, got {}'.format(
            type(pq).__name__))
    self._graphs = [dgl_graph(sbm(self._n_communities, self._block_size, *x))
                    for x in pq]
    self._line_graphs = [g.line_graph(backtracking=False)
                         for g in self._graphs]
    # In-degrees as float tensors, precomputed once for collate_fn/`__getitem__`.
    self._graph_degrees = [g.in_degrees().float() for g in self._graphs]
    self._line_graph_degrees = [lg.in_degrees().float()
                                for lg in self._line_graphs]
    # First component of each graph's edges() pair, one entry per graph.
    self._pm_pds = list(zip(*[g.edges() for g in self._graphs]))[0]
def has_cache(self):
    """Report whether all cached artifacts for this configuration exist.

    Returns
    -------
    bool
        True only if the graphs file, the line-graphs file and the info
        pickle are all present under ``self.save_path``.
    """
    filenames = ['graphs_{}.bin'.format(self.hash),
                 'line_graphs_{}.bin'.format(self.hash),
                 'info_{}.pkl'.format(self.hash)]
    return all(os.path.exists(os.path.join(self.save_path, name))
               for name in filenames)
def save(self):
    """Persist the processed graphs, line graphs and auxiliary arrays
    under ``self.save_path``, keyed by the dataset hash."""
    join = os.path.join
    graph_path = join(self.save_path, 'graphs_{}.bin'.format(self.hash))
    line_graph_path = join(self.save_path,
                           'line_graphs_{}.bin'.format(self.hash))
    info_path = join(self.save_path, 'info_{}.pkl'.format(self.hash))
    save_graphs(graph_path, self._graphs)
    save_graphs(line_graph_path, self._line_graphs)
    # Degrees and pm_pd are not DGLGraphs, so they go into the info pickle.
    info = {'graph_degree': self._graph_degrees,
            'line_graph_degree': self._line_graph_degrees,
            'pm_pds': self._pm_pds}
    save_info(info_path, info)
def load(self):
    """Restore graphs, line graphs and auxiliary arrays previously
    written by :meth:`save` from ``self.save_path``."""
    join = os.path.join
    graph_path = join(self.save_path, 'graphs_{}.bin'.format(self.hash))
    line_graph_path = join(self.save_path,
                           'line_graphs_{}.bin'.format(self.hash))
    info_path = join(self.save_path, 'info_{}.pkl'.format(self.hash))
    # load_graphs returns (graph_list, labels); labels are unused here.
    self._graphs, _ = load_graphs(graph_path)
    self._line_graphs, _ = load_graphs(line_graph_path)
    info = load_info(info_path)
    self._graph_degrees = info['graph_degree']
    self._line_graph_degrees = info['line_graph_degree']
    self._pm_pds = info['pm_pds']
def __len__(self):
    r"""Return the number of graphs in the dataset."""
    n_graphs = len(self._graphs)
    return n_graphs
def __getitem__(self, idx):
    r"""Fetch a single example by index.

    Parameters
    ----------
    idx : int
        Item index

    Returns
    -------
    graph : dgl.DGLGraph
        The original graph
    line_graph : dgl.DGLGraph
        The line graph of `graph`
    graph_degree : numpy.ndarray
        In degrees for each node in `graph`
    line_graph_degree : numpy.ndarray
        In degrees for each node in `line_graph`
    pm_pd : numpy.ndarray
        Edge indicator matrices Pm and Pd
    """
    return (self._graphs[idx],
            self._line_graphs[idx],
            self._graph_degrees[idx],
            self._line_graph_degrees[idx],
            self._pm_pds[idx])
def _appendix_c(self): def _appendix_c(self):
q = npr.uniform(0, self._avg_deg - math.sqrt(self._avg_deg)) q = npr.uniform(0, self._avg_deg - math.sqrt(self._avg_deg))
...@@ -106,6 +198,36 @@ class SBMMixture: ...@@ -106,6 +198,36 @@ class SBMMixture:
return q, p return q, p
def collate_fn(self, x): def collate_fn(self, x):
r""" The `collate` function for dataloader
Parameters
----------
x : tuple
a batch of data that contains
graph : dgl.DGLGraph
The original graph
line_graph : dgl.DGLGraph
The line graph of `graph`
graph_degree : numpy.ndarray
In degrees for each node in `graph`
line_graph_degree : numpy.ndarray
In degrees for each node in `line_graph`
pm_pd : numpy.ndarray
Edge indicator matrices Pm and Pd
Returns
-------
g_batch : dgl.DGLGraph
Batched graphs
lg_batch : dgl.DGLGraph
Batched line graphs
degg_batch : numpy.ndarray
A batch of in degrees for each node in `g_batch`
deglg_batch : numpy.ndarray
A batch of in degrees for each node in `lg_batch`
pm_pd_batch : numpy.ndarray
A batch of edge indicator matrices Pm and Pd
"""
g, lg, deg_g, deg_lg, pm_pd = zip(*x) g, lg, deg_g, deg_lg, pm_pd = zip(*x)
g_batch = batch.batch(g) g_batch = batch.batch(g)
lg_batch = batch.batch(lg) lg_batch = batch.batch(lg)
...@@ -113,3 +235,6 @@ class SBMMixture: ...@@ -113,3 +235,6 @@ class SBMMixture:
deglg_batch = np.concatenate(deg_lg, axis=0) deglg_batch = np.concatenate(deg_lg, axis=0)
pm_pd_batch = np.concatenate([x + i * self._n_nodes for i, x in enumerate(pm_pd)], axis=0) pm_pd_batch = np.concatenate([x + i * self._n_nodes for i, x in enumerate(pm_pd)], axis=0)
return g_batch, lg_batch, degg_batch, deglg_batch, pm_pd_batch return g_batch, lg_batch, degg_batch, deglg_batch, pm_pd_batch
SBMMixture = SBMMixtureDataset  # Backward-compatible alias for the pre-rename class name.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment