Unverified Commit 0a78dbe1 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub

Three DGL-GCN implementations (#35)

* synthetic dataset

* some fix

* gcn readme
parent 2c489fad

Graph Convolutional Networks (GCN)
============
Paper link: [https://arxiv.org/abs/1609.02907](https://arxiv.org/abs/1609.02907)

Author's code repo: [https://github.com/tkipf/gcn](https://github.com/tkipf/gcn)

This folder contains three different implementations of GCN using DGL.

Naive GCN (gcn.py)
-------
The model is defined at the finest granularity, i.e., on *one* edge and *one* node.
* The message function `gcn_msg` computes the message for one edge. It simply returns the `h` representation of the source node.
```python
def gcn_msg(src, edge):
    # src['h'] is a tensor of shape (D,). D is the feature length.
    return src['h']
```
* The reduce function `gcn_reduce` accumulates the incoming messages for one node. The `msgs` argument is a list of all the messages. In GCN, the incoming messages are summed up.
```python
def gcn_reduce(node, msgs):
    # msgs is a list of in-coming messages.
    return sum(msgs)
```
* The update function `NodeUpdateModule` computes the new node representation `h` by applying a non-linear transformation to the reduced messages.
```python
class NodeUpdateModule(nn.Module):
    def __init__(self, in_feats, out_feats, activation=None):
        super(NodeUpdateModule, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node, accum):
        # accum is a tensor of shape (D,).
        h = self.linear(accum)
        if self.activation:
            h = self.activation(h)
        return {'h' : h}
```
After defining the functions on each node/edge, the message passing is triggered by calling `update_all` on the DGLGraph object (in the GCN module).

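The semantics of this call can be sketched in a few lines of dependency-free Python. This is a toy stand-in for illustration only, not DGL's implementation; the graph, the scalar features, and the update lambda below are made up:

```python
def gcn_msg(src, edge):
    # message: just forward the source node's 'h'
    return src['h']

def gcn_reduce(node, msgs):
    # reduce: sum all incoming messages
    return sum(msgs)

def update_all(nodes, edges, msg_fn, reduce_fn, update_fn):
    """Toy update_all: nodes is {nid: {'h': value}}, edges is [(src, dst)]."""
    inbox = {nid: [] for nid in nodes}
    for u, v in edges:                        # message phase, one call per edge
        inbox[v].append(msg_fn(nodes[u], None))
    for nid, msgs in inbox.items():           # reduce + update phase, per node
        if msgs:
            accum = reduce_fn(nodes[nid], msgs)
            nodes[nid].update(update_fn(nodes[nid], accum))
    return nodes

# tiny graph with edges 0->2 and 1->2; scalars stand in for feature vectors
nodes = {0: {'h': 1.0}, 1: {'h': 2.0}, 2: {'h': 0.0}}
result = update_all(nodes, [(0, 2), (1, 2)], gcn_msg, gcn_reduce,
                    lambda node, accum: {'h': accum})
```

Node 2 ends up holding the sum of its in-neighbors' features; in the real model the update function would be the `NodeUpdateModule` shown earlier rather than an identity lambda.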
Batched GCN (gcn_batch.py)
-----------
Defining the model on only one node and edge makes it hard to fully utilize GPUs. As a result, we also allow users to define the model on a *batch of* nodes and edges.
* The message function `gcn_msg` computes the message for a batch of edges. Here, the `src` argument is the batched representation of the source endpoints of the edges. The function simply returns the source node representations.
```python
def gcn_msg(src, edge):
    # src is a tensor of shape (B, D). B is the number of edges being batched.
    return src
```
* The reduce function `gcn_reduce` also accumulates messages for a batch of nodes. We batch the messages on the second dimension of the `msgs` argument:
```python
def gcn_reduce(node, msgs):
    # msgs is a tensor of shape (B, deg, D). B is the number of nodes in the batch;
    # deg is the number of messages; D is the message tensor dimension. DGL guarantees
    # that all the nodes in a batch have the same in-degree (through "degree bucketing").
    # Reducing on the second dimension is equivalent to summing up all the incoming messages.
    return torch.sum(msgs, 1)
```
* The update module is similar. The first dimension of each tensor is the batch dimension. Since PyTorch operations are usually aware of the batch dimension, the code is the same as in the naive GCN.

Triggering message passing is also similar. The user needs to set `batchable=True` to indicate that the functions all support batching.
```python
self.g.update_all(gcn_msg, gcn_reduce, layer, batchable=True)
```
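A quick numpy stand-in shows what the `(B, deg, D)` reduction computes. numpy replaces torch here only to keep the sketch dependency-free, and the message values are made up:

```python
import numpy as np

# One degree bucket: B=2 nodes that both have in-degree deg=2, message dim D=3.
msgs = np.array([[[1., 0., 0.],
                  [0., 1., 0.]],
                 [[1., 1., 1.],
                  [2., 2., 2.]]])       # shape (2, 2, 3)

# Summing over dimension 1 collapses the per-message axis,
# mirroring torch.sum(msgs, 1) in gcn_reduce.
reduced = msgs.sum(axis=1)              # shape (2, 3)
```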
Batched GCN with spMV optimization (gcn_spmv.py)
-----------
Batched computation is much more efficient than the naive vertex-centric approach, but it is still not ideal. For example, the batched message function needs to look up the source node data and save it on the edges. Such lookups are very common and incur extra memory copies. In fact, the message and reduce phases of the GCN model can be fused into one sparse-matrix-vector multiplication (spMV). Therefore, DGL provides many built-in message/reduce functions so it can detect the opportunity for this optimization. In gcn_spmv.py, the user only needs to write the update module and trigger the message passing as follows:
```python
self.g.update_all('from_src', 'sum', layer, batchable=True)
```
Here, `'from_src'` and `'sum'` are the built-in message and reduce functions.
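The fusion works because summing the source features into every destination node is exactly one multiplication by the graph's adjacency matrix. A dense numpy sketch of the equivalence (a real spMV would use a sparse matrix; the toy graph and features are made up):

```python
import numpy as np

# Adjacency matrix with A[dst, src] = 1, for edges 0->2 and 1->2.
A = np.array([[0., 0., 0.],
              [0., 0., 0.],
              [1., 1., 0.]])

# Node features H, one row per node, feature dim D=2.
H = np.array([[1., 0.],
              [2., 0.],
              [0., 5.]])

# One matrix product performs both the 'from_src' message and the 'sum' reduce.
out = A @ H
```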
@@ -10,7 +10,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from dgl import DGLGraph
-from dgl.data import load_cora, load_citeseer, load_pubmed
+from dgl.data import register_data_args, load_data
 
 def gcn_msg(src, edge):
     return src['h']
@@ -58,18 +58,11 @@ class GCN(nn.Module):
             if self.dropout:
                 self.g.nodes[n]['h'] = F.dropout(g.nodes[n]['h'], p=self.dropout)
             self.g.update_all(gcn_msg, gcn_reduce, layer)
-        return torch.cat([self.g.nodes[n]['h'] for n in train_nodes])
+        return torch.cat([torch.unsqueeze(self.g.nodes[n]['h'], 0) for n in train_nodes])
 
 def main(args):
     # load and preprocess dataset
-    if args.dataset == 'cora':
-        data = load_cora()
-    elif args.dataset == 'citeseer':
-        data = load_citeseer()
-    elif args.dataset == 'pubmed':
-        data = load_pubmed()
-    else:
-        raise RuntimeError('Error dataset: {}'.format(args.dataset))
+    data = load_data(args)
 
     # features of each samples
     features = {}
@@ -90,9 +83,8 @@ def main(args):
     else:
         cuda = True
         torch.cuda.set_device(args.gpu)
-        features = features.cuda()
+        features = {k : v.cuda() for k, v in features.items()}
         labels = labels.cuda()
-        mask = mask.cuda()
 
     # create GCN model
     model = GCN(data.graph,
@@ -131,15 +123,14 @@ def main(args):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='GCN')
-    parser.add_argument("--dataset", type=str, required=True,
-            help="dataset")
+    register_data_args(parser)
     parser.add_argument("--dropout", type=float, default=0,
             help="dropout probability")
     parser.add_argument("--gpu", type=int, default=-1,
             help="gpu")
     parser.add_argument("--lr", type=float, default=1e-3,
             help="learning rate")
-    parser.add_argument("--n-epochs", type=int, default=20,
+    parser.add_argument("--n-epochs", type=int, default=10,
             help="number of training epochs")
     parser.add_argument("--n-hidden", type=int, default=16,
             help="number of hidden gcn units")
......
@@ -13,7 +13,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import dgl
 from dgl import DGLGraph
-from dgl.data import load_cora, load_citeseer, load_pubmed
+from dgl.data import register_data_args, load_data
 
 def gcn_msg(src, edge):
     return src
@@ -65,14 +65,8 @@ class GCN(nn.Module):
 def main(args):
     # load and preprocess dataset
-    if args.dataset == 'cora':
-        data = load_cora()
-    elif args.dataset == 'citeseer':
-        data = load_citeseer()
-    elif args.dataset == 'pubmed':
-        data = load_pubmed()
-    else:
-        raise RuntimeError('Error dataset: {}'.format(args.dataset))
+    data = load_data(args)
 
     features = torch.FloatTensor(data.features)
     labels = torch.LongTensor(data.labels)
     mask = torch.ByteTensor(data.train_mask)
@@ -129,8 +123,7 @@ def main(args):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='GCN')
-    parser.add_argument("--dataset", type=str, required=True,
-            help="dataset")
+    register_data_args(parser)
     parser.add_argument("--dropout", type=float, default=0,
             help="dropout probability")
     parser.add_argument("--gpu", type=int, default=-1,
......
@@ -13,7 +13,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import dgl
 from dgl import DGLGraph
-from dgl.data import load_cora, load_citeseer, load_pubmed
+from dgl.data import register_data_args, load_data
 
 class NodeUpdateModule(nn.Module):
     def __init__(self, in_feats, out_feats, activation=None):
@@ -59,14 +59,8 @@ class GCN(nn.Module):
 def main(args):
     # load and preprocess dataset
-    if args.dataset == 'cora':
-        data = load_cora()
-    elif args.dataset == 'citeseer':
-        data = load_citeseer()
-    elif args.dataset == 'pubmed':
-        data = load_pubmed()
-    else:
-        raise RuntimeError('Error dataset: {}'.format(args.dataset))
+    data = load_data(args)
 
     features = torch.FloatTensor(data.features)
     labels = torch.LongTensor(data.labels)
     mask = torch.ByteTensor(data.train_mask)
@@ -123,8 +117,7 @@ def main(args):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='GCN')
-    parser.add_argument("--dataset", type=str, required=True,
-            help="dataset")
+    register_data_args(parser)
    parser.add_argument("--dropout", type=float, default=0,
             help="dropout probability")
     parser.add_argument("--gpu", type=int, default=-1,
......
 """Data related package."""
 from __future__ import absolute_import
 from . import citation_graph as citegrh
-from .utils import *
 load_cora = citegrh.load_cora
 load_citeseer = citegrh.load_citeseer
 load_pubmed = citegrh.load_pubmed
+
+def register_data_args(parser):
+    parser.add_argument("--dataset", type=str, required=True,
+            help="The input dataset.")
+    citegrh.register_args(parser)
+
+from .utils import *
+
+def load_data(args):
+    if args.dataset == 'cora':
+        return citegrh.load_cora()
+    elif args.dataset == 'citeseer':
+        return citegrh.load_citeseer()
+    elif args.dataset == 'pubmed':
+        return citegrh.load_pubmed()
+    elif args.dataset == 'syn':
+        return citegrh.load_synthetic(args)
+    else:
+        raise ValueError('Unknown dataset: {}'.format(args.dataset))
@@ -3,6 +3,8 @@
 (lingfan): following dataset loading and preprocessing code from tkipf/gcn
 https://github.com/tkipf/gcn/blob/master/gcn/utils.py
 """
+from __future__ import absolute_import
+
 import numpy as np
 import pickle as pkl
 import networkx as nx
@@ -17,7 +19,7 @@ _urls = {
     'pubmed' : 'https://www.dropbox.com/s/fj5q6pi66xhymcm/pubmed.zip?dl=1',
 }
 
-class GCNDataset:
+class GCNDataset(object):
     def __init__(self, name):
         self.name = name
         self.dir = get_download_dir()
@@ -117,7 +119,7 @@ def _preprocess_features(features):
     r_inv[np.isinf(r_inv)] = 0.
     r_mat_inv = sp.diags(r_inv)
     features = r_mat_inv.dot(features)
-    return features.todense()
+    return np.array(features.todense())
 
 def _parse_index_file(filename):
     """Parse index file."""
@@ -146,3 +148,94 @@ def load_pubmed():
     data = GCNDataset('pubmed')
     data.load()
     return data
+
+class GCNSyntheticDataset(object):
+    def __init__(self,
+                 graph_generator,
+                 num_feats=500,
+                 num_classes=10,
+                 train_ratio=1.,
+                 val_ratio=0.,
+                 test_ratio=0.,
+                 seed=None):
+        rng = np.random.RandomState(seed)
+        # generate graph
+        self.graph = graph_generator(seed)
+        num_nodes = self.graph.number_of_nodes()
+        # generate features
+        #self.features = rng.randn(num_nodes, num_feats).astype(np.float32)
+        self.features = np.zeros((num_nodes, num_feats), dtype=np.float32)
+        # generate labels
+        self.labels = rng.randint(num_classes, size=num_nodes)
+        onehot_labels = np.zeros((num_nodes, num_classes), dtype=np.float32)
+        onehot_labels[np.arange(num_nodes), self.labels] = 1.
+        self.onehot_labels = onehot_labels
+        self.num_labels = num_classes
+        # generate masks
+        ntrain = int(num_nodes * train_ratio)
+        nval = int(num_nodes * val_ratio)
+        ntest = int(num_nodes * test_ratio)
+        mask_array = np.zeros((num_nodes,), dtype=np.int32)
+        mask_array[0:ntrain] = 1
+        mask_array[ntrain:ntrain+nval] = 2
+        mask_array[ntrain+nval:ntrain+nval+ntest] = 3
+        rng.shuffle(mask_array)
+        self.train_mask = (mask_array == 1).astype(np.int32)
+        self.val_mask = (mask_array == 2).astype(np.int32)
+        self.test_mask = (mask_array == 3).astype(np.int32)
+        print('Finished synthetic dataset generation.')
+        print('  NumNodes: {}'.format(self.graph.number_of_nodes()))
+        print('  NumEdges: {}'.format(self.graph.number_of_edges()))
+        print('  NumFeats: {}'.format(self.features.shape[1]))
+        print('  NumClasses: {}'.format(self.num_labels))
+        print('  NumTrainingSamples: {}'.format(len(np.nonzero(self.train_mask)[0])))
+        print('  NumValidationSamples: {}'.format(len(np.nonzero(self.val_mask)[0])))
+        print('  NumTestSamples: {}'.format(len(np.nonzero(self.test_mask)[0])))
+
+def get_gnp_generator(args):
+    n = args.syn_gnp_n
+    p = (2 * np.log(n) / n) if args.syn_gnp_p == 0. else args.syn_gnp_p
+    def _gen(seed):
+        return nx.fast_gnp_random_graph(n, p, seed, True)
+    return _gen
+
+def load_synthetic(args):
+    ty = args.syn_type
+    if ty == 'gnp':
+        gen = get_gnp_generator(args)
+    else:
+        raise ValueError('Unknown graph generator type: {}'.format(ty))
+    return GCNSyntheticDataset(
+            gen,
+            args.syn_nfeats,
+            args.syn_nclasses,
+            args.syn_train_ratio,
+            args.syn_val_ratio,
+            args.syn_test_ratio,
+            args.syn_seed)
+
+def register_args(parser):
+    # Args for synthetic graphs.
+    parser.add_argument('--syn-type', type=str, default='gnp',
+            help='Type of the synthetic graph generator')
+    parser.add_argument('--syn-nfeats', type=int, default=500,
+            help='Number of node features')
+    parser.add_argument('--syn-nclasses', type=int, default=10,
+            help='Number of output classes')
+    parser.add_argument('--syn-train-ratio', type=float, default=.1,
+            help='Ratio of training nodes')
+    parser.add_argument('--syn-val-ratio', type=float, default=.2,
+            help='Ratio of validation nodes')
+    parser.add_argument('--syn-test-ratio', type=float, default=.5,
+            help='Ratio of testing nodes')
+    # Args for GNP generator
+    parser.add_argument('--syn-gnp-n', type=int, default=1000,
+            help='n in gnp random graph')
+    parser.add_argument('--syn-gnp-p', type=float, default=0.0,
+            help='p in gnp random graph')
+    parser.add_argument('--syn-seed', type=int, default=42,
+            help='random seed')