"src/vscode:/vscode.git/clone" did not exist on "d0c02398b986f2876b2b79f3a137ed00a7edde35"
Unverified Commit 451ed6d8 authored by Xiangkun Hu's avatar Xiangkun Hu Committed by GitHub
Browse files

[Dataset] SBMMixtureDataset (#1920)



* PPIDataset

* Revert "PPIDataset"

This reverts commit 264bd0c960cfa698a7bb946dad132bf52c2d0c8a.

* SBMMixture dataset

* Update sbm.py

* fix example

* Update sbm.py

* Revert "Update sbm.py"

This reverts commit 066db5c89bd5e159981ae7cad1bfb883ea5db71d.

* Update sbm.py
Co-authored-by: default avatarxiang song(charlie.song) <classicxsong@gmail.com>
parent 4a8cc489
......@@ -17,7 +17,7 @@ import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from dgl.data import SBMMixture
from dgl.data import SBMMixtureDataset
import gnn
parser = argparse.ArgumentParser()
......@@ -38,7 +38,7 @@ args = parser.parse_args()
dev = th.device('cpu') if args.gpu < 0 else th.device('cuda:%d' % args.gpu)
K = args.n_communities
training_dataset = SBMMixture(args.n_graphs, args.n_nodes, K)
training_dataset = SBMMixtureDataset(args.n_graphs, args.n_nodes, K)
training_loader = DataLoader(training_dataset, args.batch_size,
collate_fn=training_dataset.collate_fn, drop_last=True)
......@@ -105,7 +105,7 @@ def test():
N = 1
overlap_list = []
for p, q in zip(p_list, q_list):
dataset = SBMMixture(N, args.n_nodes, K, pq=[[p, q]] * N)
dataset = SBMMixtureDataset(N, args.n_nodes, K, pq=[[p, q]] * N)
loader = DataLoader(dataset, N, collate_fn=dataset.collate_fn)
g, lg, deg_g, deg_lg, pm_pd = next(iter(loader))
z = inference(g, lg, deg_g, deg_lg, pm_pd)
......
......@@ -6,7 +6,7 @@ from .citation_graph import CoraBinary, CitationGraphDataset
from .minigc import *
from .tree import SST, SSTDataset
from .utils import *
from .sbm import SBMMixture
from .sbm import SBMMixture, SBMMixtureDataset
from .reddit import RedditDataset
from .ppi import PPIDataset, LegacyPPIDataset
from .tu import TUDataset, LegacyTUDataset
......
"""Dataset for stochastic block model."""
import math
import random
import os
import numpy as np
import numpy.random as npr
import scipy as sp
from .. import convert
from .dgl_dataset import DGLDataset
from ..convert import graph as dgl_graph
from .. import batch
from .utils import save_info, save_graphs, load_info, load_graphs
def sbm(n_blocks, block_size, p, q, rng=None):
""" (Symmetric) Stochastic Block Model
......@@ -22,6 +26,8 @@ def sbm(n_blocks, block_size, p, q, rng=None):
Probability for intra-community edge.
q : float
Probability for inter-community edge.
rng : numpy.random.RandomState, optional
Random number generator.
Returns
-------
......@@ -49,9 +55,12 @@ def sbm(n_blocks, block_size, p, q, rng=None):
adj = sp.sparse.triu(a) + sp.sparse.triu(a, 1).transpose()
return adj
class SBMMixture:
""" Symmetric Stochastic Block Model Mixture
Please refer to Appendix C of "Supervised Community Detection with Hierarchical Graph Neural Networks" (https://arxiv.org/abs/1705.08415) for details.
class SBMMixtureDataset(DGLDataset):
r""" Symmetric Stochastic Block Model Mixture
Reference: Appendix C of "Supervised Community Detection with Hierarchical
Graph Neural Networks" (https://arxiv.org/abs/1705.08415).
Parameters
----------
......@@ -62,40 +71,123 @@ class SBMMixture:
n_communities : int
Number of communities.
k : int, optional
Multiplier.
Multiplier. Default: 2
avg_deg : int, optional
Average degree.
Average degree. Default: 3
pq : list of pair of nonnegative float or str, optional
Random densities.
Random densities. This parameter is for future extension,
for now it's always using the default value.
Default: Appendix_C
rng : numpy.random.RandomState, optional
Random number generator.
Random number generator. If not given, it's numpy.random.RandomState() with `seed=None`,
which read data from /dev/urandom (or the Windows analogue) if available or seed from
the clock otherwise.
Default: None
Raises
------
RuntimeError is raised if pq is not a list or string.
Examples
--------
>>> data = SBMMixtureDataset(n_graphs=16, n_nodes=10000, n_communities=2)
>>> from torch.utils.data import DataLoader
>>> dataloader = DataLoader(data, batch_size=1, collate_fn=data.collate_fn)
>>> for graph, line_graph, graph_degrees, line_graph_degrees, pm_pd in dataloader:
... # your code here
"""
def __init__(self, n_graphs, n_nodes, n_communities,
k=2, avg_deg=3, pq='Appendix C', rng=None):
def __init__(self,
n_graphs,
n_nodes,
n_communities,
k=2,
avg_deg=3,
pq='Appendix_C',
rng=None):
self._n_graphs = n_graphs
self._n_nodes = n_nodes
self._n_communities = n_communities
assert n_nodes % n_communities == 0
block_size = n_nodes // n_communities
self._block_size = n_nodes // n_communities
self._k = k
self._avg_deg = avg_deg
self._pq = pq
self._rng = rng
super(SBMMixtureDataset, self).__init__(name='sbmmixture',
hash_key=(n_graphs, n_nodes, n_communities, k, avg_deg, pq, rng))
def process(self):
pq = self._pq
if type(pq) is list:
assert len(pq) == n_graphs
assert len(pq) == self._n_graphs
elif type(pq) is str:
generator = {'Appendix C' : self._appendix_c}[pq]
pq = [generator() for i in range(n_graphs)]
generator = {'Appendix_C': self._appendix_c}[pq]
pq = [generator() for _ in range(self._n_graphs)]
else:
raise RuntimeError()
self._gs = [convert.graph(sbm(n_communities, block_size, *x)) for x in pq]
self._lgs = [g.line_graph(backtracking=False) for g in self._gs]
self._g_degs = [g.in_degrees().float() for g in self._gs]
self._lg_degs = [lg.in_degrees().float() for lg in self._lgs]
self._pm_pds = list(zip(*[g.edges() for g in self._gs]))[0]
self._graphs = [dgl_graph(sbm(self._n_communities, self._block_size, *x)) for x in pq]
self._line_graphs = [g.line_graph(backtracking=False) for g in self._graphs]
in_degrees = lambda g: g.in_degrees().float()
self._graph_degrees = [in_degrees(g) for g in self._graphs]
self._line_graph_degrees = [in_degrees(lg) for lg in self._line_graphs]
self._pm_pds = list(zip(*[g.edges() for g in self._graphs]))[0]
def has_cache(self):
graph_path = os.path.join(self.save_path, 'graphs_{}.bin'.format(self.hash))
line_graph_path = os.path.join(self.save_path, 'line_graphs_{}.bin'.format(self.hash))
info_path = os.path.join(self.save_path, 'info_{}.pkl'.format(self.hash))
return os.path.exists(graph_path) and \
os.path.exists(line_graph_path) and \
os.path.exists(info_path)
def save(self):
graph_path = os.path.join(self.save_path, 'graphs_{}.bin'.format(self.hash))
line_graph_path = os.path.join(self.save_path, 'line_graphs_{}.bin'.format(self.hash))
info_path = os.path.join(self.save_path, 'info_{}.pkl'.format(self.hash))
save_graphs(graph_path, self._graphs)
save_graphs(line_graph_path, self._line_graphs)
save_info(info_path, {'graph_degree': self._graph_degrees,
'line_graph_degree': self._line_graph_degrees,
'pm_pds': self._pm_pds})
def load(self):
graph_path = os.path.join(self.save_path, 'graphs_{}.bin'.format(self.hash))
line_graph_path = os.path.join(self.save_path, 'line_graphs_{}.bin'.format(self.hash))
info_path = os.path.join(self.save_path, 'info_{}.pkl'.format(self.hash))
self._graphs, _ = load_graphs(graph_path)
self._line_graphs, _ = load_graphs(line_graph_path)
info = load_info(info_path)
self._graph_degrees = info['graph_degree']
self._line_graph_degrees = info['line_graph_degree']
self._pm_pds = info['pm_pds']
def __len__(self):
return len(self._gs)
r"""Number of graphs in the dataset."""
return len(self._graphs)
def __getitem__(self, idx):
return self._gs[idx], self._lgs[idx], \
self._g_degs[idx], self._lg_degs[idx], self._pm_pds[idx]
r""" Get one example by index
Parameters
----------
idx : int
Item index
Returns
-------
graph : dgl.DGLGraph
The original graph
line_graph : dgl.DGLGraph
The line graph of `graph`
graph_degree : numpy.ndarray
In degrees for each node in `graph`
line_graph_degree : numpy.ndarray
In degrees for each node in `line_graph`
pm_pd : numpy.ndarray
Edge indicator matrices Pm and Pd
"""
return self._graphs[idx], self._line_graphs[idx], \
self._graph_degrees[idx], self._line_graph_degrees[idx], self._pm_pds[idx]
def _appendix_c(self):
q = npr.uniform(0, self._avg_deg - math.sqrt(self._avg_deg))
......@@ -106,6 +198,36 @@ class SBMMixture:
return q, p
def collate_fn(self, x):
r""" The `collate` function for dataloader
Parameters
----------
x : tuple
a batch of data that contains
graph : dgl.DGLGraph
The original graph
line_graph : dgl.DGLGraph
The line graph of `graph`
graph_degree : numpy.ndarray
In degrees for each node in `graph`
line_graph_degree : numpy.ndarray
In degrees for each node in `line_graph`
pm_pd : numpy.ndarray
Edge indicator matrices Pm and Pd
Returns
-------
g_batch : dgl.DGLGraph
Batched graphs
lg_batch : dgl.DGLGraph
Batched line graphs
degg_batch : numpy.ndarray
A batch of in degrees for each node in `g_batch`
deglg_batch : numpy.ndarray
A batch of in degrees for each node in `lg_batch`
pm_pd_batch : numpy.ndarray
A batch of edge indicator matrices Pm and Pd
"""
g, lg, deg_g, deg_lg, pm_pd = zip(*x)
g_batch = batch.batch(g)
lg_batch = batch.batch(lg)
......@@ -113,3 +235,6 @@ class SBMMixture:
deglg_batch = np.concatenate(deg_lg, axis=0)
pm_pd_batch = np.concatenate([x + i * self._n_nodes for i, x in enumerate(pm_pd)], axis=0)
return g_batch, lg_batch, degg_batch, deglg_batch, pm_pd_batch
SBMMixture = SBMMixtureDataset
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment