Unverified commit 565f0c88 authored by Minjie Wang, committed by GitHub

[WIP] [NN] Refactor NN package (#406)

* refactor graph conv

* docs & tests

* fix lint

* fix lint

* fix lint

* fix lint script

* fix lint

* Update

* Style fix

* Fix style

* Fix style

* Fix gpu case

* Fix for gpu case

* Hotfix edgesoftmax docs

* Handle repeated features

* Add docstring

* Set default arguments

* Remove dropout from nn.conv

* Fix

* add util fn for renaming

* revert gcn_spmv.py

* mx folder

* fix weird bug

* fix mx

* fix lint
parent 8c750170
"""Torch modules for graph convolutions."""
# pylint: disable= no-member, arguments-differ
import torch as th
from torch import nn
from torch.nn import init
from ... import function as fn
from ...utils import get_ndata_name
__all__ = ['GraphConv']
class GraphConv(nn.Module):
r"""Apply graph convolution over an input signal.
Graph convolution was introduced in `GCN <https://arxiv.org/abs/1609.02907>`__
and can be described as follows:
.. math::
h_i^{(l+1)} = \sigma(b^{(l)} + \sum_{j\in\mathcal{N}(i)}\frac{1}{c_{ij}}W^{(l)}h_j^{(l)})
where :math:`\mathcal{N}(i)` is the neighbor set of node :math:`i`, :math:`c_{ij}` is the
product of the square roots of the node degrees,
:math:`\sqrt{|\mathcal{N}(i)|}\sqrt{|\mathcal{N}(j)|}`, and :math:`\sigma` is an activation
function.
The model parameters are initialized as in the
`original implementation <https://github.com/tkipf/gcn/blob/master/gcn/layers.py>`__ where
the weight :math:`W^{(l)}` is initialized using Glorot uniform initialization
and the bias is initialized to be zero.
Notes
-----
Nodes with zero in-degree lead to an invalid normalizer. A common practice
to avoid this is to add a self-loop to each node in the graph, which
can be achieved by:
>>> g = ... # some DGLGraph
>>> g.add_edges(g.nodes(), g.nodes())
Parameters
----------
in_feats : int
Number of input features.
out_feats : int
Number of output features.
norm : bool, optional
If True, the normalizer :math:`c_{ij}` is applied. Default: ``True``.
bias : bool, optional
If True, adds a learnable bias to the output. Default: ``True``.
activation : callable activation function/layer or None, optional
If not None, applies an activation function to the updated node features.
Default: ``None``.
Attributes
----------
weight : torch.Tensor
The learnable weight tensor.
bias : torch.Tensor
The learnable bias tensor.
"""
def __init__(self,
in_feats,
out_feats,
norm=True,
bias=True,
activation=None):
super(GraphConv, self).__init__()
self._in_feats = in_feats
self._out_feats = out_feats
self._norm = norm
self._feat_name = "_gconv_feat"
self._msg_name = "_gconv_msg"
self.weight = nn.Parameter(th.Tensor(in_feats, out_feats))
if bias:
self.bias = nn.Parameter(th.Tensor(out_feats))
else:
self.register_parameter('bias', None)
self.reset_parameters()
self._activation = activation
def reset_parameters(self):
"""Reinitialize learnable parameters."""
init.xavier_uniform_(self.weight)
if self.bias is not None:
init.zeros_(self.bias)
def forward(self, feat, graph):
r"""Compute graph convolution.
Notes
-----
* Input shape: :math:`(N, *, \text{in_feats})` where * means any number of additional
dimensions, :math:`N` is the number of nodes.
* Output shape: :math:`(N, *, \text{out_feats})` where all but the last dimension are
the same shape as the input.
Parameters
----------
feat : torch.Tensor
The input feature
graph : DGLGraph
The graph.
Returns
-------
torch.Tensor
The output feature
"""
self._feat_name = get_ndata_name(graph, self._feat_name)
if self._norm:
norm = th.pow(graph.in_degrees().float(), -0.5)
shp = norm.shape + (1,) * (feat.dim() - 1)
norm = th.reshape(norm, shp).to(feat.device)
feat = feat * norm
if self._in_feats > self._out_feats:
# mult W first to reduce the feature size for aggregation.
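# (aggregation creates one message per edge, so projecting to the smaller
# output dimension before update_all reduces message size and compute)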
feat = th.matmul(feat, self.weight)
graph.ndata[self._feat_name] = feat
graph.update_all(fn.copy_src(src=self._feat_name, out=self._msg_name),
fn.sum(msg=self._msg_name, out=self._feat_name))
rst = graph.ndata.pop(self._feat_name)
else:
# aggregate first then mult W
graph.ndata[self._feat_name] = feat
graph.update_all(fn.copy_src(src=self._feat_name, out=self._msg_name),
fn.sum(msg=self._msg_name, out=self._feat_name))
rst = graph.ndata.pop(self._feat_name)
rst = th.matmul(rst, self.weight)
if self._norm:
rst = rst * norm
if self.bias is not None:
rst = rst + self.bias
if self._activation is not None:
rst = self._activation(rst)
return rst
def extra_repr(self):
"""Set the extra representation of the module,
which will come into effect when printing the model.
"""
summary = 'in={_in_feats}, out={_out_feats}'
summary += ', normalization={_norm}'
summary += ', activation={_activation}'
return summary.format(**self.__dict__)
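As a quick sanity check of the layer above, here is a minimal usage sketch (it mirrors the tests later in this commit; the self-loop trick follows the docstring note, and the input values and printed shape are illustrative):

import torch as th
import networkx as nx
import dgl
from dgl.nn.pytorch import GraphConv

# small path graph; add self-loops so that no node has zero in-degree
g = dgl.DGLGraph(nx.path_graph(3))
g.add_edges(g.nodes(), g.nodes())

conv = GraphConv(in_feats=5, out_feats=2, activation=th.relu)
feat = th.ones(3, 5)    # one 5-dimensional feature per node
out = conv(feat, g)     # note the argument order: feature first, then graph
print(out.shape)        # torch.Size([3, 2])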
"""
Semi-Supervised Classification with Graph Convolutional Networks
Paper: https://arxiv.org/abs/1609.02907
Code: https://github.com/tkipf/gcn
GCN with SPMV specialization.
"""
import torch.nn as nn
from ... import function as fn
from ...base import ALL, is_all
class NodeUpdateModule(nn.Module):
def __init__(self, node_field, in_feats, out_feats, activation=None):
super(NodeUpdateModule, self).__init__()
self.node_field = node_field
self.linear = nn.Linear(in_feats, out_feats)
self.activation = activation
def forward(self, node):
h = self.linear(node[self.node_field])
if self.activation:
h = self.activation(h)
return {self.node_field: h}
class GraphConvolutionLayer(nn.Module):
"""Single graph convolution layer as in https://arxiv.org/abs/1609.02907."""
def __init__(self,
node_field,
in_feats,
out_feats,
activation,
dropout=0):
"""
node_filed: hashable keys for node features, e.g. 'h'
msg_field: hashable keys for message features, e.g. 'm'. In GCN, this is
just AH, where A is the adjacency matrix and H is current node features.
"""
super(GraphConvolutionLayer, self).__init__()
self.node_field = node_field
if dropout:
self.dropout = nn.Dropout(p=dropout)
else:
self.dropout = 0.
# input layer
self.update_func = NodeUpdateModule(node_field, in_feats, out_feats,
activation)
def forward(self, g, u=ALL, v=ALL):
if self.dropout:
g.apply_nodes(u, apply_node_func=
lambda node: {self.node_field: self.dropout(node[self.node_field])})
if is_all(u) and is_all(v):
g.update_all(fn.copy_src(src=self.node_field, out='m'),
fn.sum(msg='m', out=self.node_field),
self.update_func)
else:
g.send_and_recv(u, v,
fn.copy_src(src=self.node_field, out='m'),
fn.sum(msg='m', out=self.node_field),
self.update_func)
return g
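For context, a hedged sketch of how the SPMV-specialized layer above is typically driven (assuming GraphConvolutionLayer is imported from this example module; the field name 'h' and the feature sizes are illustrative):

import torch as th
import networkx as nx
import dgl

g = dgl.DGLGraph(nx.path_graph(4))
g.ndata['h'] = th.randn(4, 8)    # features must live on the graph beforehand

layer = GraphConvolutionLayer('h', in_feats=8, out_feats=4, activation=th.relu)
g = layer(g)                     # full-graph update via update_all
out = g.ndata['h']               # updated features are written back to 'h'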
"""Torch modules for graph related softmax."""
# pylint: disable= no-member, arguments-differ
import torch as th
from torch import nn
from ... import function as fn
from ...utils import get_ndata_name
__all__ = ['EdgeSoftmax']
class EdgeSoftmax(nn.Module):
r"""Apply softmax over signals of incoming edges.
For a node :math:`i`, edge softmax is an operation that computes
.. math::
a_{ij} = \frac{\exp(z_{ij})}{\sum_{j\in\mathcal{N}(i)}\exp(z_{ij})}
where :math:`z_{ij}` is a signal of edge :math:`j\rightarrow i`, also
called logits in the context of softmax. :math:`\mathcal{N}(i)` is
the set of nodes that have an edge to :math:`i`.
An example of using edge softmax is the
`Graph Attention Network <https://arxiv.org/pdf/1710.10903.pdf>`__, where
the attention weights are computed with such an edge softmax operation.
"""
def __init__(self):
super(EdgeSoftmax, self).__init__()
# compute the softmax
self._logits_name = "_logits"
self._max_logits_name = "_max_logits"
self._normalizer_name = "_norm"
def forward(self, logits, graph):
r"""Compute edge softmax.
Parameters
----------
logits : torch.Tensor
The input edge feature
graph : DGLGraph
The graph.
Returns
-------
Unnormalized scores : torch.Tensor
This part gives the unnormalized scores :math:`\exp(z_{ij})`
Normalizer : torch.Tensor
This part gives the per-node normalizer :math:`\sum_{j\in\mathcal{N}(i)}\exp(z_{ij})`
Notes
-----
* Input shape: :math:`(N, *, 1)` where * means any number of additional
dimensions, :math:`N` is the number of edges.
* Unnormalized scores shape: :math:`(N, *, 1)` where all but the last
dimension are the same shape as the input.
* Normalizer shape: :math:`(M, *, 1)` where :math:`M` is the number of
nodes and all but the first and the last dimensions are the same as
the input.
Note that this computation is still one step away from the actual softmax
results. The last step can be performed as follows:
>>> import dgl.function as fn
>>>
>>> scores, normalizer = EdgeSoftmax(...).forward(logits, graph)
>>> graph.edata['a'] = scores
>>> graph.ndata['normalizer'] = normalizer
>>> graph.apply_edges(lambda edges : {'a' : edges.data['a'] / edges.dst['normalizer']})
We leave this last step to users because, depending on the particular use case,
it can be fused with other computation in one pass.
"""
self._logits_name = get_ndata_name(graph, self._logits_name)
self._max_logits_name = get_ndata_name(graph, self._max_logits_name)
self._normalizer_name = get_ndata_name(graph, self._normalizer_name)
graph.edata[self._logits_name] = logits
# compute the softmax
graph.update_all(fn.copy_edge(self._logits_name, self._logits_name),
fn.max(self._logits_name, self._max_logits_name))
# subtract the per-node max and exponentiate (for numerical stability)
graph.apply_edges(
lambda edges: {self._logits_name : th.exp(edges.data[self._logits_name] -
edges.dst[self._max_logits_name])})
# compute normalizer
graph.update_all(fn.copy_edge(self._logits_name, self._logits_name),
fn.sum(self._logits_name, self._normalizer_name))
return graph.edata.pop(self._logits_name), graph.ndata.pop(self._normalizer_name)
def __repr__(self):
return 'EdgeSoftmax()'
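A short standalone sketch of the final normalization described in the docstring above (the field names 'a' and 'a_sum' are illustrative; the same pattern appears in the unit tests below):

import torch as th
import networkx as nx
import dgl
from dgl.nn.pytorch import EdgeSoftmax

g = dgl.DGLGraph(nx.path_graph(3))
logits = th.randn(g.number_of_edges(), 1)

scores, normalizer = EdgeSoftmax()(logits, g)
g.edata['a'] = scores
g.ndata['a_sum'] = normalizer
# divide each edge score by its destination node's normalizer
g.apply_edges(lambda edges: {'a': edges.data['a'] / edges.dst['a_sum']})
# g.edata['a'] now holds attention weights that sum to one over incoming edges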
@@ -462,3 +462,24 @@ def reorder_index(idx, order):
def is_iterable(obj):
    """Return true if the object is an iterable."""
    return isinstance(obj, Iterable)
def get_ndata_name(g, name):
"""Return a node data name that does not exist in the given graph.
The given name is directly returned if it does not exist in the given graph.
Parameters
----------
g : DGLGraph
The graph.
name : str
The proposed name.
Returns
-------
str
The node data name that does not exist.
"""
while name in g.ndata:
name += '_'
return name
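To illustrate the helper added above: it appends underscores until the proposed name no longer collides with existing node data. A small sketch, assuming the helper lives in dgl.utils as the `from ...utils import get_ndata_name` imports suggest (the collision below is hypothetical):

import torch as th
import networkx as nx
import dgl
from dgl.utils import get_ndata_name

g = dgl.DGLGraph(nx.path_graph(3))
g.ndata['_gconv_feat'] = th.zeros(3, 1)        # user data already uses this name
assert get_ndata_name(g, '_gconv_feat') == '_gconv_feat_'
assert get_ndata_name(g, 'h') == 'h'           # unused names are returned as-is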
@@ -5,11 +5,11 @@
 */
#include <dgl/sampler.h>
-#include <dmlc/omp.h>
#include <dgl/immutable_graph.h>
#include <algorithm>
#include <cstdlib>
#include <cmath>
+#include <dmlc/omp.h>
#ifdef _MSC_VER
// rand in MS compiler works well in multi-threading.
@@ -322,7 +322,7 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list,
size_t pos = neigh_pos->at(i).pos;
CHECK_LE(pos, neighbor_list.size());
size_t num_edges = neigh_pos->at(i).num_edges;
-if (neighbor_list.empty()) CHECK(num_edges == 0);
+if (neighbor_list.empty()) CHECK_EQ(num_edges, 0);
// We need to map the Ids of the neighbors to the subgraph.
auto neigh_it = neighbor_list.begin() + pos;
@@ -470,7 +470,7 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph,
edge_type, num_edges, num_hops, graph->IsMultigraph());
}
-}  // namespace anonymous
+}  // namespace
NodeFlow SamplerOp::NeighborUniformSample(const ImmutableGraph *graph, IdArray seeds,
const std::string &edge_type,
...
@@ -7,7 +7,7 @@ extension-pkg-whitelist=
# Add files or directories to the blacklist. They should be base names, not
# paths.
-ignore=CVS,_cy2,_cy3,backend,data,nn,contrib
+ignore=CVS,_cy2,_cy3,backend,data,contrib
# Add files or directories matching the regex patterns to the blacklist. The
# regex matches against base names, not paths.
...
import mxnet as mx
import networkx as nx
import numpy as np
import dgl
import dgl.nn.mxnet as nn
from mxnet import autograd
def check_eq(a, b):
assert a.shape == b.shape
assert mx.nd.sum(a == b).asnumpy() == int(np.prod(list(a.shape)))
def _AXWb(A, X, W, b):
X = mx.nd.dot(X, W.data(X.context))
Y = mx.nd.dot(A, X.reshape(X.shape[0], -1)).reshape(X.shape)
return Y + b.data(X.context)
def test_graph_conv():
g = dgl.DGLGraph(nx.path_graph(3))
adj = g.adjacency_matrix()
ctx = mx.cpu(0)
conv = nn.GraphConv(5, 2, norm=False, bias=True)
conv.initialize(ctx=ctx)
# test#1: basic
h0 = mx.nd.ones((3, 5))
h1 = conv(h0, g)
check_eq(h1, _AXWb(adj, h0, conv.weight, conv.bias))
# test#2: more-dim
h0 = mx.nd.ones((3, 5, 5))
h1 = conv(h0, g)
check_eq(h1, _AXWb(adj, h0, conv.weight, conv.bias))
conv = nn.GraphConv(5, 2)
conv.initialize(ctx=ctx)
# test#3: basic
h0 = mx.nd.ones((3, 5))
h1 = conv(h0, g)
# test#4: basic
h0 = mx.nd.ones((3, 5, 5))
h1 = conv(h0, g)
conv = nn.GraphConv(5, 2)
conv.initialize(ctx=ctx)
with autograd.train_mode():
# test#3: basic
h0 = mx.nd.ones((3, 5))
h1 = conv(h0, g)
# test#4: basic
h0 = mx.nd.ones((3, 5, 5))
h1 = conv(h0, g)
# test repeated features
g.ndata["_gconv_feat"] = 2 * mx.nd.ones((3, 1))
h1 = conv(h0, g)
assert "_gconv_feat" in g.ndata
if __name__ == '__main__':
test_graph_conv()
"""
Placeholder file for framework-specific test
"""
import torch as th
import networkx as nx
import dgl
import dgl.nn.pytorch as nn
from copy import deepcopy
def _AXWb(A, X, W, b):
X = th.matmul(X, W)
Y = th.matmul(A, X.view(X.shape[0], -1)).view_as(X)
return Y + b
def test_graph_conv():
g = dgl.DGLGraph(nx.path_graph(3))
adj = g.adjacency_matrix()
conv = nn.GraphConv(5, 2, norm=False, bias=True)
print(conv)
# test#1: basic
h0 = th.ones((3, 5))
h1 = conv(h0, g)
assert th.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
# test#2: more-dim
h0 = th.ones((3, 5, 5))
h1 = conv(h0, g)
assert th.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
conv = nn.GraphConv(5, 2)
# test#3: basic
h0 = th.ones((3, 5))
h1 = conv(h0, g)
# test#4: basic
h0 = th.ones((3, 5, 5))
h1 = conv(h0, g)
conv = nn.GraphConv(5, 2)
# test#3: basic
h0 = th.ones((3, 5))
h1 = conv(h0, g)
# test#4: basic
h0 = th.ones((3, 5, 5))
h1 = conv(h0, g)
# test reset_parameters
old_weight = deepcopy(conv.weight.data)
conv.reset_parameters()
new_weight = conv.weight.data
assert not th.allclose(old_weight, new_weight)
def uniform_attention(g, shape):
a = th.ones(shape)
target_shape = (g.number_of_edges(),) + (1,) * (len(shape) - 1)
return a / g.in_degrees(g.edges()[1]).view(target_shape).float()
def test_edge_softmax():
edge_softmax = nn.EdgeSoftmax()
print(edge_softmax)
# Basic
g = dgl.DGLGraph(nx.path_graph(3))
edata = th.ones(g.number_of_edges(), 1)
unnormalized, normalizer = edge_softmax(edata, g)
g.edata["a"] = unnormalized
g.ndata["a_sum"] = normalizer
g.apply_edges(lambda edges : {"a": edges.data["a"] / edges.dst["a_sum"]})
assert th.allclose(g.edata["a"], uniform_attention(g, unnormalized.shape))
# Test higher dimension case
edata = th.ones(g.number_of_edges(), 3, 1)
unnormalized, normalizer = edge_softmax(edata, g)
g.edata["a"] = unnormalized
g.ndata["a_sum"] = normalizer
g.apply_edges(lambda edges : {"a": edges.data["a"] / edges.dst["a_sum"]})
assert th.allclose(g.edata["a"], uniform_attention(g, unnormalized.shape))
if __name__ == '__main__':
test_graph_conv()
test_edge_softmax()
@@ -2,8 +2,8 @@
# cpplint
echo 'Checking code style of C++ codes...'
-python3 third_party/dmlc-core/scripts/lint.py dgl cpp include src
+python3 third_party/dmlc-core/scripts/lint.py dgl cpp include src || exit 1
# pylint
echo 'Checking code style of python codes...'
-python3 -m pylint --reports=y -v --rcfile=tests/lint/pylintrc python/dgl
+python3 -m pylint --reports=y -v --rcfile=tests/lint/pylintrc python/dgl || exit 1