Unverified Commit 2758c249 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[NN] Fix GCN module (#99)

1. Update `examples/pytorch/gcn` and `python/dgl/nn/pytorch` based on the latest APIs
2. Add full support for dropout in `examples/pytorch/gcn` and `python/dgl/nn/pytorch`
3. Rename `GCN` class in `python/dgl/nn/pytorch` to be `GraphConvolutionLayer` class
4. Make node field an argument that can be configured by users in GraphConvolutionLayer

Note that adjacency normalization has not been supported yet in the examples. 
parent d0ea98be
Graph Convolutional Networks (GCN)
============
Paper link: [https://arxiv.org/abs/1609.02907](https://arxiv.org/abs/1609.02907)
Author's code repo: [https://github.com/tkipf/gcn](https://github.com/tkipf/gcn)
- Paper link: [https://arxiv.org/abs/1609.02907](https://arxiv.org/abs/1609.02907)
- Author's code repo: [https://github.com/tkipf/gcn](https://github.com/tkipf/gcn). Note that the original code is
implemented with Tensorflow for the paper.
The folder contains two different implementations using DGL.
......@@ -14,28 +15,29 @@ Defining the model on only one node and edge makes it hard to fully utilize GPUs
```python
def gcn_msg(src, edge):
# src is a tensor of shape (B, D). B is the number of edges being batched.
return src
return {'m' : src['h']}
```
* The reduce function `gcn_reduce` also accumulates messages for a batch of nodes. We batch the messages on the second dimension fo the `msgs` argument:
* The reduce function `gcn_reduce` also accumulates messages for a batch of nodes. We batch the messages on the second dimension for the `msgs` argument,
which for example can correspond to the neighbors of the nodes:
```python
def gcn_reduce(node, msgs):
# The msgs is a tensor of shape (B, deg, D). B is the number of nodes in the batch;
# deg is the number of messages; D is the message tensor dimension. DGL gaurantees
# that all the nodes in a batch have the same in-degrees (through "degree-bucketing").
# Reduce on the second dimension is equal to sum up all the in-coming messages.
return torch.sum(msgs, 1)
return {'h' : torch.sum(msgs['m'], 1)}
```
* The update module is similar. The first dimension of each tensor is the batch dimension. Since PyTorch operation is usually aware of the batch dimension, the code is the same as the naive GCN.
Triggering message passing is also similar. User needs to set `batchable=True` to indicate that the functions all support batching.
Triggering message passing is also similar.
```python
self.g.update_all(gcn_msg, gcn_reduce, layer, batchable=True)`
self.g.update_all(gcn_msg, gcn_reduce, layer)`
```
Batched GCN with spMV optimization (gcn_spmv.py)
-----------
Batched computation is much more efficient than naive vertex-centric approach, but is still not ideal. For example, the batched message function needs to look up source node data and save it on edges. Such kind of lookups is very common and incurs extra memory copy operations. In fact, the message and reduce phase of GCN model can be fused into one sparse-matrix-vector multiplication (spMV). Therefore, DGL provides many built-in message/reduce functions so we can figure out the chance of optimization. In gcn_spmv.py, user only needs to write update module and trigger the message passing as follows:
```python
self.g.update_all('from_src', 'sum', layer, batchable=True)
self.g.update_all(fn.copy_src(src='h', out='m'), fn.sum(msg='m', out='h'), layer)
```
Here, `'from_src'` and `'sum'` are the builtin message and reduce function.
Here, `'fn.copy_src'` and `'fn.sum'` are the builtin message and reduce functions that perform the same operations as `gcn_msg` and `gcn_reduce` in gcn.py.
......@@ -11,7 +11,6 @@ import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
......@@ -24,6 +23,7 @@ def gcn_reduce(node, msgs):
class NodeApplyModule(nn.Module):
def __init__(self, in_feats, out_feats, activation=None):
super(NodeApplyModule, self).__init__()
self.linear = nn.Linear(in_feats, out_feats)
self.activation = activation
......@@ -31,6 +31,7 @@ class NodeApplyModule(nn.Module):
h = self.linear(node['h'])
if self.activation:
h = self.activation(h)
return {'h' : h}
class GCN(nn.Module):
......@@ -44,27 +45,36 @@ class GCN(nn.Module):
dropout):
super(GCN, self).__init__()
self.g = g
self.dropout = dropout
if dropout:
self.dropout = nn.Dropout(p=dropout)
else:
self.dropout = 0.
# input layer
self.layers = nn.ModuleList([NodeApplyModule(in_feats, n_hidden, activation)])
# hidden layers
for i in range(n_layers - 1):
self.layers.append(NodeApplyModule(n_hidden, n_hidden, activation))
# output layer
self.layers.append(NodeApplyModule(n_hidden, n_classes))
def forward(self, features):
self.g.set_n_repr({'h' : features})
for layer in self.layers:
# apply dropout
if self.dropout:
g.apply_nodes(apply_node_func=
lambda node: F.dropout(node['h'], p=self.dropout))
self.g.apply_nodes(apply_node_func=
lambda node: {'h': self.dropout(node['h'])})
self.g.update_all(gcn_msg, gcn_reduce, layer)
return self.g.pop_n_repr('h')
def main(args):
# load and preprocess dataset
# Todo: adjacency normalization
data = load_data(args)
features = torch.FloatTensor(data.features)
......
......@@ -11,7 +11,6 @@ import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
......@@ -19,6 +18,7 @@ from dgl.data import register_data_args, load_data
class NodeApplyModule(nn.Module):
def __init__(self, in_feats, out_feats, activation=None):
super(NodeApplyModule, self).__init__()
self.linear = nn.Linear(in_feats, out_feats)
self.activation = activation
......@@ -26,7 +26,8 @@ class NodeApplyModule(nn.Module):
h = self.linear(node['h'])
if self.activation:
h = self.activation(h)
return {'h' : h}
return {'h': h}
class GCN(nn.Module):
def __init__(self,
......@@ -39,22 +40,30 @@ class GCN(nn.Module):
dropout):
super(GCN, self).__init__()
self.g = g
self.dropout = dropout
if dropout:
self.dropout = nn.Dropout(p=dropout)
else:
self.dropout = 0.
# input layer
self.layers = nn.ModuleList([NodeApplyModule(in_feats, n_hidden, activation)])
# hidden layers
for i in range(n_layers - 1):
self.layers.append(NodeApplyModule(n_hidden, n_hidden, activation))
# output layer
self.layers.append(NodeApplyModule(n_hidden, n_classes))
def forward(self, features):
self.g.set_n_repr({'h' : features})
for layer in self.layers:
# apply dropout
if self.dropout:
g.apply_nodes(apply_node_func=
lambda node: F.dropout(node['h'], p=self.dropout))
self.g.apply_nodes(apply_node_func=
lambda node: {'h': self.dropout(node['h'])})
self.g.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'),
layer)
......@@ -62,6 +71,7 @@ class GCN(nn.Module):
def main(args):
# load and preprocess dataset
# Todo: adjacency normalization
data = load_data(args)
features = torch.FloatTensor(data.features)
......
from .gcn import GCN
from .gcn import GraphConvolutionLayer
......@@ -10,44 +10,61 @@ import torch.nn as nn
from ... import function as fn
from ...base import ALL, is_all
class NodeUpdateModule(nn.Module):
def __init__(self, in_feats, out_feats, activation=None):
def __init__(self, node_field, in_feats, out_feats, activation=None):
super(NodeUpdateModule, self).__init__()
self.node_field = node_field
self.linear = nn.Linear(in_feats, out_feats)
self.activation = activation
self.attribute = None
def forward(self, node):
h = self.linear(node['accum'])
h = self.linear(node[self.node_field])
if self.activation:
h = self.activation(h)
if self.attribute:
return {self.attribute: h}
else:
return h
class GCN(nn.Module):
return {self.node_field: h}
class GraphConvolutionLayer(nn.Module):
"""Single graph convolution layer as in https://arxiv.org/abs/1609.02907."""
def __init__(self,
node_field,
in_feats,
out_feats,
activation,
dropout=0):
super(GCN, self).__init__()
self.dropout = dropout
"""
node_filed: hashable keys for node features, e.g. 'h'
msg_field: hashable keys for message features, e.g. 'm'. In GCN, this is
just AH, where A is the adjacency matrix and H is current node features.
"""
super(GraphConvolutionLayer, self).__init__()
self.node_field = node_field
if dropout:
self.dropout = nn.Dropout(p=dropout)
else:
self.dropout = 0.
# input layer
self.update_func = NodeUpdateModule(in_feats, out_feats, activation)
self.update_func = NodeUpdateModule(node_field, in_feats, out_feats,
activation)
def forward(self, g, u=ALL, v=ALL):
if self.dropout:
g.apply_nodes(u, apply_node_func=
lambda node: {self.node_field: self.dropout(node[self.node_field])})
def forward(self, g, u=ALL, v=ALL, attribute=None):
if is_all(u) and is_all(v):
g.update_all(fn.copy_src(src=attribute),
fn.sum(out='accum'),
self.update_func,
batchable=True)
g.update_all(fn.copy_src(src=self.node_field, out='m'),
fn.sum(msg='m', out=self.node_field),
self.update_func)
else:
g.send_and_recv(u, v,
fn.copy_src(src=attribute),
fn.sum(out='accum'),
self.update_func,
batchable=True)
g.pop_n_repr('accum')
fn.copy_src(src=self.node_field, out='m'),
fn.sum(msg='m', out=self.node_field),
self.update_func)
return g
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment