Commit b1e8d95e authored by HQ's avatar HQ Committed by Minjie Wang
Browse files

[Doc] LGNN tutorial (#210)

* tutorial notebook added

* lg tutorial cleaned up

* dataset scaffold

* move dataloader to data

* fix model

* remove todo

* utils separated

* [model]line graph new implementation + tutorial + binary sub graph dataset

* [tutorial] line graph sphinx scaffold

* [tutorial] lgnn tutorial improved

* [tutorial] remove notebook

* [tutorial] fix lg and gcn links

* [tutorial] fix random seed

* [tutorial] fix

* WIP

* code refactor done

* new mini dataset; remove utils code

* fix

* word fix

* fix link

* minor fix

* minor fix

* minor fix
parent 10e18ed9
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
from __future__ import absolute_import from __future__ import absolute_import
from . import citation_graph as citegrh from . import citation_graph as citegrh
from .citation_graph import CoraBinary
from .tree import * from .tree import *
from .utils import * from .utils import *
from .sbm import SBMMixture from .sbm import SBMMixture
......
...@@ -11,14 +11,22 @@ import networkx as nx ...@@ -11,14 +11,22 @@ import networkx as nx
import scipy.sparse as sp import scipy.sparse as sp
import os, sys import os, sys
import dgl
from .utils import download, extract_archive, get_download_dir, _get_dgl_url from .utils import download, extract_archive, get_download_dir, _get_dgl_url
_urls = { _urls = {
'cora' : 'dataset/cora.zip', 'cora' : 'dataset/cora.zip',
'citeseer' : 'dataset/citeseer.zip', 'citeseer' : 'dataset/citeseer.zip',
'pubmed' : 'dataset/pubmed.zip', 'pubmed' : 'dataset/pubmed.zip',
'cora_binary' : 'dataset/cora_binary.zip',
} }
def _pickle_load(pkl_file):
if sys.version_info > (3, 0):
return pkl.load(pkl_file, encoding='latin1')
else:
return pkl.load(pkl_file)
class CitationGraphDataset(object): class CitationGraphDataset(object):
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name
...@@ -52,10 +60,7 @@ class CitationGraphDataset(object): ...@@ -52,10 +60,7 @@ class CitationGraphDataset(object):
objects = [] objects = []
for i in range(len(objnames)): for i in range(len(objnames)):
with open("{}/ind.{}.{}".format(root, self.name, objnames[i]), 'rb') as f: with open("{}/ind.{}.{}".format(root, self.name, objnames[i]), 'rb') as f:
if sys.version_info > (3, 0): objects.append(_pickle_load(f))
objects.append(pkl.load(f, encoding='latin1'))
else:
objects.append(pkl.load(f))
x, y, tx, ty, allx, ally, graph = tuple(objects) x, y, tx, ty, allx, ally, graph = tuple(objects)
test_idx_reorder = _parse_index_file("{}/ind.{}.test.index".format(root, self.name)) test_idx_reorder = _parse_index_file("{}/ind.{}.test.index".format(root, self.name))
...@@ -265,3 +270,68 @@ def register_args(parser): ...@@ -265,3 +270,68 @@ def register_args(parser):
help='p in gnp random graph') help='p in gnp random graph')
parser.add_argument('--syn-seed', type=int, default=42, parser.add_argument('--syn-seed', type=int, default=42,
help='random seed') help='random seed')
class CoraBinary(object):
    """A mini-dataset for a binary classification task using Cora.

    The ``cora_binary`` archive is downloaded and extracted on
    construction.  After loading, the dataset exposes three parallel
    lists (one entry per sample):

    graphs : list of :class:`~dgl.DGLGraph`
    pmpds : list of :class:`scipy.sparse.coo_matrix`
    labels : list of :class:`numpy.ndarray`
    """
    def __init__(self):
        self.dir = get_download_dir()
        self.name = 'cora_binary'
        self.zip_file_path = '{}/{}.zip'.format(self.dir, self.name)
        download(_get_dgl_url(_urls[self.name]), path=self.zip_file_path)
        extract_archive(self.zip_file_path, '{}/{}'.format(self.dir, self.name))
        self._load()

    def _load(self):
        """Parse the extracted files into ``graphs``, ``pmpds`` and ``labels``.

        In ``graphs.txt`` and ``labels.txt`` each sample is introduced by
        a line starting with ``'graph'``; the sample's records follow
        until the next marker (or end of file).
        """
        root = '{}/{}'.format(self.dir, self.name)
        # load graphs
        self.graphs = []
        with open("{}/graphs.txt".format(root), 'r') as f:
            elist = []
            # Iterate the file lazily instead of f.readlines() so the
            # whole file is never materialized in memory at once.
            for line in f:
                if line.startswith('graph'):
                    if len(elist) != 0:
                        self.graphs.append(dgl.DGLGraph(elist))
                    elist = []
                else:
                    u, v = line.strip().split(' ')
                    elist.append((int(u), int(v)))
            # Flush the trailing sample: no marker follows the last one.
            if len(elist) != 0:
                self.graphs.append(dgl.DGLGraph(elist))
        with open("{}/pmpds.pkl".format(root), 'rb') as f:
            self.pmpds = _pickle_load(f)
        self.labels = []
        with open("{}/labels.txt".format(root), 'r') as f:
            cur = []
            for line in f:
                if line.startswith('graph'):
                    if len(cur) != 0:
                        self.labels.append(np.array(cur))
                    cur = []
                else:
                    cur.append(int(line.strip()))
            if len(cur) != 0:
                self.labels.append(np.array(cur))
        # sanity check: all three lists must describe the same samples
        assert len(self.graphs) == len(self.pmpds)
        assert len(self.graphs) == len(self.labels)

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.graphs)

    def __getitem__(self, i):
        """Return the ``(graph, pmpd, label)`` triple of sample ``i``."""
        return (self.graphs[i], self.pmpds[i], self.labels[i])

    @staticmethod
    def collate_fn(batch):
        """Collate ``(graph, pmpd, label)`` samples into a single batch.

        Graphs are merged with :func:`dgl.batch`, the sparse matrices are
        stacked block-diagonally, and the label arrays are concatenated.
        """
        graphs, pmpds, labels = zip(*batch)
        batched_graphs = dgl.batch(graphs)
        batched_pmpds = sp.block_diag(pmpds)
        batched_labels = np.concatenate(labels, axis=0)
        return batched_graphs, batched_pmpds, batched_labels
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
Graph Convolutional Network Graph Convolutional Network
==================================== ====================================
**Author:** Qi Huang, `Minjie Wang <https://jermainewang.github.io/>`_, **Author:** `Qi Huang <https://github.com/HQ01>`_, `Minjie Wang <https://jermainewang.github.io/>`_,
Yu Gai, Quan Gan, Zheng Zhang Yu Gai, Quan Gan, Zheng Zhang
This is a gentle introduction of using DGL to implement Graph Convolutional This is a gentle introduction of using DGL to implement Graph Convolutional
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment