Unverified Commit 2758c249 authored by Mufei Li, committed by GitHub

[NN] Fix GCN module (#99)

1. Update `examples/pytorch/gcn` and `python/dgl/nn/pytorch` based on the latest APIs
2. Add full support for dropout in `examples/pytorch/gcn` and `python/dgl/nn/pytorch`
3. Rename the `GCN` class in `python/dgl/nn/pytorch` to `GraphConvolutionLayer`
4. Make the node field a user-configurable argument of `GraphConvolutionLayer`

Note that adjacency normalization is not yet supported in the examples.
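For context, the missing normalization is the symmetric renormalization trick from the paper. A minimal sketch of it, assuming a SciPy sparse adjacency matrix (this is not code from the commit):

```python
# Sketch only, not part of this commit: the symmetric normalization from the paper,
# D^{-1/2} (A + I) D^{-1/2}, assuming `adj` is a scipy.sparse adjacency matrix.
import numpy as np
import scipy.sparse as sp

def normalize_adj(adj):
    adj = adj + sp.eye(adj.shape[0])              # add self-loops
    deg = np.asarray(adj.sum(axis=1)).flatten()   # node degrees
    d_inv_sqrt = np.power(deg, -0.5)
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.         # guard against isolated nodes
    d_mat = sp.diags(d_inv_sqrt)
    return d_mat.dot(adj).dot(d_mat)              # D^{-1/2} (A + I) D^{-1/2}
```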
parent d0ea98be
examples/pytorch/gcn/README.md

Graph Convolutional Networks (GCN)
============
- Paper link: [https://arxiv.org/abs/1609.02907](https://arxiv.org/abs/1609.02907)
- Author's code repo: [https://github.com/tkipf/gcn](https://github.com/tkipf/gcn). Note that the original code for the paper is implemented in TensorFlow.

The folder contains two different implementations using DGL.
@@ -14,28 +15,29 @@ Defining the model on only one node and edge makes it hard to fully utilize GPUs
```python
def gcn_msg(src, edge):
    # src['h'] is a tensor of shape (B, D). B is the number of edges being batched.
    return {'m' : src['h']}
```
* The reduce function `gcn_reduce` also accumulates messages for a batch of nodes. We batch the messages on the second dimension of the `msgs` argument, which for example can correspond to the neighbors of the nodes:
```python
def gcn_reduce(node, msgs):
    # msgs['m'] is a tensor of shape (B, deg, D). B is the number of nodes in the batch;
    # deg is the number of messages; D is the message tensor dimension. DGL guarantees
    # that all the nodes in a batch have the same in-degrees (through "degree-bucketing").
    # Reducing on the second dimension is equivalent to summing up all incoming messages.
    return {'h' : torch.sum(msgs['m'], 1)}
```
* The update module is similar. The first dimension of each tensor is the batch dimension. Since PyTorch operations are usually aware of the batch dimension, the code is the same as in the naive GCN; a minimal sketch is given below.
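For reference, this is what such an update module looks like; it mirrors the `NodeApplyModule` defined in gcn.py in this commit, and the node field `'h'` matches the message/reduce functions above:

```python
import torch.nn as nn

class NodeApplyModule(nn.Module):
    """Per-node update: a linear transform followed by an optional activation."""
    def __init__(self, in_feats, out_feats, activation=None):
        super(NodeApplyModule, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node):
        # node['h'] is a tensor of shape (B, D_in); the result overwrites the 'h' field.
        h = self.linear(node['h'])
        if self.activation:
            h = self.activation(h)
        return {'h': h}
```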
Triggering message passing is also similar:
```python
self.g.update_all(gcn_msg, gcn_reduce, layer)
```
Batched GCN with spMV optimization (gcn_spmv.py)
-----------
Batched computation is much more efficient than the naive vertex-centric approach, but it is still not ideal. For example, the batched message function needs to look up source node data and save it on edges. Such lookups are very common and incur extra memory copies. In fact, the message and reduce phases of the GCN model can be fused into one sparse matrix-vector multiplication (spMV). Therefore, DGL provides many built-in message/reduce functions so that it can recognize the opportunity for this optimization. In gcn_spmv.py, the user only needs to write the update module and trigger the message passing as follows:
```python
self.g.update_all(fn.copy_src(src='h', out='m'), fn.sum(msg='m', out='h'), layer)
```
Here, `'from_src'` and `'sum'` are the builtin message and reduce function. Here, `'fn.copy_src'` and `'fn.sum'` are the builtin message and reduce functions that perform the same operations as `gcn_msg` and `gcn_reduce` in gcn.py.
examples/pytorch/gcn/gcn.py

@@ -11,7 +11,6 @@ import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
@@ -24,6 +23,7 @@ def gcn_reduce(node, msgs):
class NodeApplyModule(nn.Module):
    def __init__(self, in_feats, out_feats, activation=None):
        super(NodeApplyModule, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation
@@ -31,6 +31,7 @@ class NodeApplyModule(nn.Module):
        h = self.linear(node['h'])
        if self.activation:
            h = self.activation(h)
        return {'h' : h}

class GCN(nn.Module):
@@ -44,27 +45,36 @@ class GCN(nn.Module):
                 dropout):
        super(GCN, self).__init__()
        self.g = g
        if dropout:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = 0.
        # input layer
        self.layers = nn.ModuleList([NodeApplyModule(in_feats, n_hidden, activation)])
        # hidden layers
        for i in range(n_layers - 1):
            self.layers.append(NodeApplyModule(n_hidden, n_hidden, activation))
        # output layer
        self.layers.append(NodeApplyModule(n_hidden, n_classes))

    def forward(self, features):
        self.g.set_n_repr({'h' : features})
        for layer in self.layers:
            # apply dropout
            if self.dropout:
                self.g.apply_nodes(apply_node_func=
                                   lambda node: {'h': self.dropout(node['h'])})
            self.g.update_all(gcn_msg, gcn_reduce, layer)
        return self.g.pop_n_repr('h')

def main(args):
    # load and preprocess dataset
    # TODO: adjacency normalization
    data = load_data(args)
    features = torch.FloatTensor(data.features)
...
examples/pytorch/gcn/gcn_spmv.py

@@ -11,7 +11,6 @@ import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
@@ -19,6 +18,7 @@ from dgl.data import register_data_args, load_data
class NodeApplyModule(nn.Module):
    def __init__(self, in_feats, out_feats, activation=None):
        super(NodeApplyModule, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation
@@ -26,7 +26,8 @@ class NodeApplyModule(nn.Module):
        h = self.linear(node['h'])
        if self.activation:
            h = self.activation(h)
        return {'h': h}

class GCN(nn.Module):
    def __init__(self,
@@ -39,22 +40,30 @@ class GCN(nn.Module):
                 dropout):
        super(GCN, self).__init__()
        self.g = g
        if dropout:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = 0.
        # input layer
        self.layers = nn.ModuleList([NodeApplyModule(in_feats, n_hidden, activation)])
        # hidden layers
        for i in range(n_layers - 1):
            self.layers.append(NodeApplyModule(n_hidden, n_hidden, activation))
        # output layer
        self.layers.append(NodeApplyModule(n_hidden, n_classes))

    def forward(self, features):
        self.g.set_n_repr({'h' : features})
        for layer in self.layers:
            # apply dropout
            if self.dropout:
                self.g.apply_nodes(apply_node_func=
                                   lambda node: {'h': self.dropout(node['h'])})
            self.g.update_all(fn.copy_src(src='h', out='m'),
                              fn.sum(msg='m', out='h'),
                              layer)
@@ -62,6 +71,7 @@ class GCN(nn.Module):
def main(args):
    # load and preprocess dataset
    # TODO: adjacency normalization
    data = load_data(args)
    features = torch.FloatTensor(data.features)
...
python/dgl/nn/pytorch/__init__.py

from .gcn import GraphConvolutionLayer

python/dgl/nn/pytorch/gcn.py

@@ -10,44 +10,61 @@ import torch.nn as nn
from ... import function as fn
from ...base import ALL, is_all

class NodeUpdateModule(nn.Module):
    def __init__(self, node_field, in_feats, out_feats, activation=None):
        super(NodeUpdateModule, self).__init__()
        self.node_field = node_field
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node):
        h = self.linear(node[self.node_field])
        if self.activation:
            h = self.activation(h)
        return {self.node_field: h}

class GraphConvolutionLayer(nn.Module):
    """Single graph convolution layer as in https://arxiv.org/abs/1609.02907."""
    def __init__(self,
                 node_field,
                 in_feats,
                 out_feats,
                 activation,
                 dropout=0):
        """
        node_field: hashable key for node features, e.g. 'h'
        msg_field: hashable key for message features, e.g. 'm'. In GCN, the message is
            just AH, where A is the adjacency matrix and H is the current node features.
        """
        super(GraphConvolutionLayer, self).__init__()
        self.node_field = node_field
        if dropout:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = 0.
        # input layer
        self.update_func = NodeUpdateModule(node_field, in_feats, out_feats,
                                            activation)

    def forward(self, g, u=ALL, v=ALL):
        if self.dropout:
            g.apply_nodes(u, apply_node_func=
                          lambda node: {self.node_field: self.dropout(node[self.node_field])})
        if is_all(u) and is_all(v):
            g.update_all(fn.copy_src(src=self.node_field, out='m'),
                         fn.sum(msg='m', out=self.node_field),
                         self.update_func)
        else:
            g.send_and_recv(u, v,
                            fn.copy_src(src=self.node_field, out='m'),
                            fn.sum(msg='m', out=self.node_field),
                            self.update_func)
        return g
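Finally, a hypothetical usage sketch of the renamed layer. Only the `GraphConvolutionLayer` constructor arguments and the `set_n_repr`/`pop_n_repr` calls come from the code above; the toy graph built from networkx, the feature sizes, and the `dgl.nn.pytorch` import path are assumptions made for illustration.

```python
import networkx as nx
import torch
import torch.nn.functional as F
from dgl import DGLGraph
from dgl.nn.pytorch import GraphConvolutionLayer

# Hypothetical toy graph: a 4-node path graph with 5-dimensional node features.
g = DGLGraph(nx.path_graph(4))
g.set_n_repr({'h': torch.randn(4, 5)})

# The node field is now user-configurable; here we keep the conventional 'h'.
layer = GraphConvolutionLayer(node_field='h', in_feats=5, out_feats=2,
                              activation=F.relu, dropout=0.5)
g = layer(g)
h = g.pop_n_repr('h')  # expected shape: (4, 2)
```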