Unverified Commit 0855d255 authored by Tong He, committed by GitHub

[NN] Support scalar edge weight for GraphConv, SAGEConv and GINConv (#2557)

* add edge weight in forward

* fix lint

* fix

* fix

* address comments

* add utils

* add util to normalize in gcn way

* fix lint

* add unittest

* fix lint

* fix docstring

* fix docstring

* address comments

* improve notation consistence

* use preferred fn
parent 8900450d
@@ -17,6 +17,13 @@ GraphConv
:members: weight, bias, forward, reset_parameters
:show-inheritance:
EdgeWeightNorm
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: dgl.nn.pytorch.conv.EdgeWeightNorm
:members: forward
:show-inheritance:
RelGraphConv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -235,15 +235,15 @@ def src_mul_edge(src, edge, out):
----------
src : str
The source feature field.
edge : str
The edge feature field.
out : str
The output message field.
Examples
--------
>>> import dgl
>>> message_func = dgl.function.src_mul_edge('h', 'h', 'm')
>>> message_func = dgl.function.src_mul_edge('h', 'e', 'm')
"""
return getattr(sys.modules[__name__], "u_mul_e")(src, edge, out)
@@ -8,7 +8,7 @@ from .edgeconv import EdgeConv
from .gatconv import GATConv
from .ginconv import GINConv
from .gmmconv import GMMConv
from .graphconv import GraphConv
from .graphconv import GraphConv, EdgeWeightNorm
from .nnconv import NNConv
from .relgraphconv import RelGraphConv
from .sageconv import SAGEConv
@@ -22,7 +22,7 @@ from .atomicconv import AtomicConv
from .cfconv import CFConv
from .dotgatconv import DotGatConv
__all__ = ['GraphConv', 'GATConv', 'TAGConv', 'RelGraphConv', 'SAGEConv',
__all__ = ['GraphConv', 'EdgeWeightNorm', 'GATConv', 'TAGConv', 'RelGraphConv', 'SAGEConv',
'SGConv', 'APPNPConv', 'GINConv', 'GatedGraphConv', 'GMMConv',
'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv', 'DenseSAGEConv',
'DenseChebConv', 'EdgeConv', 'AtomicConv', 'CFConv', 'DotGatConv']
@@ -20,6 +20,16 @@ class GINConv(nn.Module):
\mathrm{aggregate}\left(\left\{h_j^{l}, j\in\mathcal{N}(i)
\right\}\right)\right)
If a weight tensor on each edge is provided, the weighted graph convolution is defined as:
.. math::
h_i^{(l+1)} = f_\Theta \left((1 + \epsilon) h_i^{l} +
\mathrm{aggregate}\left(\left\{e_{ji} h_j^{l}, j\in\mathcal{N}(i)
\right\}\right)\right)
where :math:`e_{ji}` is the weight on the edge from node :math:`j` to node :math:`i`.
Please make sure that :math:`e_{ji}` is broadcastable with :math:`h_j^{l}`.
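For illustration, a minimal sketch of the weighted case (the graph, features,
and ``apply_func`` below are illustrative assumptions, not part of this PR):

>>> import dgl
>>> import torch as th
>>> from dgl.nn import GINConv
>>> g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
>>> feat = th.ones(4, 5)
>>> conv = GINConv(th.nn.Linear(5, 5), 'sum')
>>> edge_weight = th.ones(g.number_of_edges())
>>> res = conv(g, feat, edge_weight=edge_weight)  # shape (4, 5)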
Parameters
----------
apply_func : callable activation function/layer or None
@@ -80,7 +90,7 @@ class GINConv(nn.Module):
else:
self.register_buffer('eps', th.FloatTensor([init_eps]))
def forward(self, graph, feat):
def forward(self, graph, feat, edge_weight=None):
r"""
Description
@@ -98,6 +108,9 @@
:math:`(N_{in}, D_{in})` and :math:`(N_{out}, D_{in})`.
If ``apply_func`` is not None, :math:`D_{in}` should
fit the input dimensionality requirement of ``apply_func``.
edge_weight : torch.Tensor, optional
Optional tensor of scalar weights on the edges. If given, each message
is scaled by the weight of its edge before aggregation.
Returns
-------
@@ -108,9 +121,15 @@
as input dimensionality.
"""
with graph.local_scope():
aggregate_fn = fn.copy_src('h', 'm')
if edge_weight is not None:
assert edge_weight.shape[0] == graph.number_of_edges()
graph.edata['_edge_weight'] = edge_weight
aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm')
feat_src, feat_dst = expand_as_pair(feat, graph)
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_u('h', 'm'), self._reducer('m', 'neigh'))
graph.update_all(aggregate_fn, self._reducer('m', 'neigh'))
rst = (1 + self.eps) * feat_dst + graph.dstdata['neigh']
if self.apply_func is not None:
rst = self.apply_func(rst)
@@ -7,6 +7,135 @@ from torch.nn import init
from .... import function as fn
from ....base import DGLError
from ....utils import expand_as_pair
from ....transform import reverse
from ....convert import block_to_graph
from ....heterograph import DGLBlock
class EdgeWeightNorm(nn.Module):
r"""
Description
-----------
This module normalizes positive scalar edge weights on a graph
following the form in `GCN <https://arxiv.org/abs/1609.02907>`__.
Mathematically, setting ``norm='both'`` yields the following normalization term:
.. math::
    c_{ji} = \sqrt{\sum_{k\in\mathcal{N}(j)}e_{jk}}\sqrt{\sum_{k\in\mathcal{N}(i)}e_{ki}}
And, setting ``norm='right'`` yields the following normalization term:
.. math::
    c_{ji} = \sum_{k\in\mathcal{N}(i)}e_{ki}
where :math:`e_{ji}` is the scalar weight on the edge from node :math:`j` to node :math:`i`.
The module returns the normalized weight :math:`e_{ji} / c_{ji}`.
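For a quick worked instance (hypothetical weights): if node :math:`i` receives
edges with weights 0.5 and 0.9, then under ``norm='right'`` the normalizer is
:math:`c_{ji} = 0.5 + 0.9 = 1.4`, so the normalized weights are roughly 0.357 and 0.643.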
Parameters
----------
norm : str, optional
The normalizer as specified above. Default is `'both'`.
eps : float, optional
A small offset value in the denominator. Default is 0.
Examples
--------
>>> import dgl
>>> import numpy as np
>>> import torch as th
>>> from dgl.nn import EdgeWeightNorm, GraphConv
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> g = dgl.add_self_loop(g)
>>> feat = th.ones(6, 10)
>>> edge_weight = th.tensor([0.5, 0.6, 0.4, 0.7, 0.9, 0.1, 1, 1, 1, 1, 1, 1])
>>> norm = EdgeWeightNorm(norm='both')
>>> norm_edge_weight = norm(g, edge_weight)
>>> conv = GraphConv(10, 2, norm='none', weight=True, bias=True)
>>> res = conv(g, feat, edge_weight=norm_edge_weight)
>>> print(res)
tensor([[-1.1849, -0.7525],
[-1.3514, -0.8582],
[-1.2384, -0.7865],
[-1.9949, -1.2669],
[-1.3658, -0.8674],
[-0.8323, -0.5286]], grad_fn=<AddBackward0>)
"""
def __init__(self, norm='both', eps=0.):
super(EdgeWeightNorm, self).__init__()
self._norm = norm
self._eps = eps
def forward(self, graph, edge_weight):
r"""
Description
-----------
Compute normalized edge weight for the GCN model.
Parameters
----------
graph : DGLGraph
The graph.
edge_weight : torch.Tensor
Unnormalized scalar weights on the edges.
The shape is expected to be :math:`(|E|)`.
Returns
-------
torch.Tensor
The normalized edge weight.
Raises
------
DGLError
Case 1:
The edge weight is multi-dimensional. Currently this module
only supports a scalar weight on each edge.
Case 2:
The edge weight has non-positive values with ``norm='both'``.
This will trigger square root and division by a non-positive number.
"""
with graph.local_scope():
if isinstance(graph, DGLBlock):
graph = block_to_graph(graph)
if len(edge_weight.shape) > 1:
raise DGLError('Currently the normalization is only defined '
'on scalar edge weight. Please customize the '
'normalization for your high-dimensional weights.')
if self._norm == 'both' and th.any(edge_weight <= 0).item():
raise DGLError('Non-positive edge weight detected with `norm="both"`. '
'This leads to square root of zero or negative values.')
dev = graph.device
graph.srcdata['_src_out_w'] = th.ones((graph.number_of_src_nodes())).float().to(dev)
graph.dstdata['_dst_in_w'] = th.ones((graph.number_of_dst_nodes())).float().to(dev)
graph.edata['_edge_w'] = edge_weight
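            # For norm='both', scale each source node by the inverse square root
            # of its weighted out-degree, computed by summing edge weights on the
            # reversed graph.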
if self._norm == 'both':
reversed_g = reverse(graph)
reversed_g.edata['_edge_w'] = edge_weight
reversed_g.update_all(fn.copy_edge('_edge_w', 'm'), fn.sum('m', 'out_weight'))
degs = reversed_g.dstdata['out_weight'] + self._eps
norm = th.pow(degs, -0.5)
graph.srcdata['_src_out_w'] = norm
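            # For norm='both' and norm='right', scale each destination node by its
            # weighted in-degree, raised to -1/2 or -1 respectively.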
if self._norm != 'none':
graph.update_all(fn.copy_edge('_edge_w', 'm'), fn.sum('m', 'in_weight'))
degs = graph.dstdata['in_weight'] + self._eps
if self._norm == 'both':
norm = th.pow(degs, -0.5)
else:
norm = 1.0 / degs
graph.dstdata['_dst_in_w'] = norm
graph.apply_edges(lambda e: {'_norm_edge_weights': e.src['_src_out_w'] * \
e.dst['_dst_in_w'] * \
e.data['_edge_w']})
return graph.edata['_norm_edge_weights']
# pylint: disable=W0235
class GraphConv(nn.Module):
@@ -18,13 +147,25 @@ class GraphConv(nn.Module):
and mathematically is defined as follows:
.. math::
h_i^{(l+1)} = \sigma(b^{(l)} + \sum_{j\in\mathcal{N}(i)}\frac{1}{c_{ij}}h_j^{(l)}W^{(l)})
h_i^{(l+1)} = \sigma(b^{(l)} + \sum_{j\in\mathcal{N}(i)}\frac{1}{c_{ji}}h_j^{(l)}W^{(l)})
where :math:`\mathcal{N}(i)` is the set of neighbors of node :math:`i`,
:math:`c_{ij}` is the product of the square root of node degrees
(i.e., :math:`c_{ij} = \sqrt{|\mathcal{N}(i)|}\sqrt{|\mathcal{N}(j)|}`),
:math:`c_{ji}` is the product of the square root of node degrees
(i.e., :math:`c_{ji} = \sqrt{|\mathcal{N}(j)|}\sqrt{|\mathcal{N}(i)|}`),
and :math:`\sigma` is an activation function.
If a weight tensor on each edge is provided, the weighted graph convolution is defined as:
.. math::
h_i^{(l+1)} = \sigma(b^{(l)} + \sum_{j\in\mathcal{N}(i)}\frac{e_{ji}}{c_{ji}}h_j^{(l)}W^{(l)})
where :math:`e_{ji}` is the scalar weight on the edge from node :math:`j` to node :math:`i`.
This is NOT equivalent to the weighted graph convolutional network formulation in the paper.
To customize the normalization term :math:`c_{ji}`, one can first set ``norm='none'`` for
the model, and send the pre-normalized :math:`e_{ji}` to the forward computation. We provide
:class:`~dgl.nn.pytorch.EdgeWeightNorm` to normalize scalar edge weight following the GCN paper.
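As a minimal sketch (the graph and values are illustrative assumptions), raw scalar
weights can be passed directly when normalization is handled elsewhere or disabled:

>>> import dgl
>>> import torch as th
>>> from dgl.nn import GraphConv
>>> g = dgl.add_self_loop(dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0])))
>>> feat = th.ones(4, 10)
>>> conv = GraphConv(10, 2, norm='none')
>>> edge_weight = th.rand(g.number_of_edges())
>>> res = conv(g, feat, edge_weight=edge_weight)  # shape (4, 2)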
Parameters
----------
in_feats : int
@@ -35,7 +176,7 @@
How to apply the normalizer. If `'right'`, divide the aggregated messages
by each node's in-degree, which is equivalent to averaging the received messages.
If `'none'`, no normalization is applied. Default is `'both'`,
where the :math:`c_{ij}` in the paper is applied.
where the :math:`c_{ji}` in the paper is applied.
weight : bool, optional
If True, apply a linear layer. Otherwise, aggregate the messages
without a weight matrix.
@@ -185,7 +326,7 @@ class GraphConv(nn.Module):
"""
self._allow_zero_in_degree = set_value
def forward(self, graph, feat, weight=None):
def forward(self, graph, feat, weight=None, edge_weight=None):
r"""
Description
@@ -205,6 +346,9 @@
:math:`(N_{out}, D_{in_{dst}})`.
weight : torch.Tensor, optional
Optional external weight tensor.
edge_weight : torch.Tensor, optional
Optional tensor of scalar weights on the edges. If given, each message
is scaled by the weight of its edge before aggregation.
Returns
-------
@@ -243,6 +387,11 @@
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
aggregate_fn = fn.copy_src('h', 'm')
if edge_weight is not None:
assert edge_weight.shape[0] == graph.number_of_edges()
graph.edata['_edge_weight'] = edge_weight
aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm')
# (BarclayII) For RGCN on heterogeneous graphs we need to support GCN on bipartite.
feat_src, feat_dst = expand_as_pair(feat, graph)
@@ -266,14 +415,12 @@
if weight is not None:
feat_src = th.matmul(feat_src, weight)
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'))
graph.update_all(aggregate_fn, fn.sum(msg='m', out='h'))
rst = graph.dstdata['h']
else:
# aggregate first then mult W
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'))
graph.update_all(aggregate_fn, fn.sum(msg='m', out='h'))
rst = graph.dstdata['h']
if weight is not None:
rst = th.matmul(rst, weight)
@@ -25,6 +25,15 @@ class SAGEConv(nn.Module):
h_{i}^{(l+1)} &= \mathrm{norm}(h_{i}^{l})
If a weight tensor on each edge is provided, the aggregation becomes:
.. math::
h_{\mathcal{N}(i)}^{(l+1)} &= \mathrm{aggregate}
\left(\{e_{ji} h_{j}^{l}, \forall j \in \mathcal{N}(i) \}\right)
where :math:`e_{ji}` is the scalar weight on the edge from node :math:`j` to node :math:`i`.
Please make sure that :math:`e_{ji}` is broadcastable with :math:`h_j^{l}`.
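A minimal sketch with the ``'mean'`` aggregator (the graph and weights are
illustrative assumptions):

>>> import dgl
>>> import torch as th
>>> from dgl.nn import SAGEConv
>>> g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
>>> feat = th.ones(4, 10)
>>> conv = SAGEConv(10, 2, 'mean')
>>> edge_weight = th.rand(g.number_of_edges())
>>> res = conv(g, feat, edge_weight=edge_weight)  # shape (4, 2)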
Parameters
----------
in_feats : int, or pair of ints
@@ -147,7 +156,7 @@ class SAGEConv(nn.Module):
_, (rst, _) = self.lstm(m, h)
return {'neigh': rst.squeeze(0)}
def forward(self, graph, feat):
def forward(self, graph, feat, edge_weight=None):
r"""
Description
@@ -164,6 +173,9 @@
where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
If a pair of torch.Tensor is given, the pair must contain two tensors of shape
:math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
edge_weight : torch.Tensor, optional
Optional tensor of scalar weights on the edges. If given, each message
is scaled by the weight of its edge before aggregation.
Returns
-------
@@ -179,6 +191,11 @@
feat_src = feat_dst = self.feat_drop(feat)
if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()]
aggregate_fn = fn.copy_src('h', 'm')
if edge_weight is not None:
assert edge_weight.shape[0] == graph.number_of_edges()
graph.edata['_edge_weight'] = edge_weight
aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm')
h_self = feat_dst
@@ -189,23 +206,23 @@
if self._aggre_type == 'mean':
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
graph.update_all(aggregate_fn, fn.mean('m', 'neigh'))
h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'gcn':
check_eq_shape(feat)
graph.srcdata['h'] = feat_src
graph.dstdata['h'] = feat_dst # same as above if homogeneous
graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh'))
graph.update_all(aggregate_fn, fn.sum('m', 'neigh'))
# divide by in-degree + 1; the +1 accounts for the node's own feature
# included in the GCN-style aggregation above
degs = graph.in_degrees().to(feat_dst)
h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
elif self._aggre_type == 'pool':
graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
graph.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
graph.update_all(aggregate_fn, fn.max('m', 'neigh'))
h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'lstm':
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_src('h', 'm'), self._lstm_reducer)
graph.update_all(aggregate_fn, self._lstm_reducer)
h_neigh = graph.dstdata['neigh']
else:
raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))
@@ -88,6 +88,45 @@ def test_graph_conv(idtype, g, norm, weight, bias):
h_out = conv(g, h, weight=ext_w)
assert h_out.shape == (ndst, 2)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['has_scalar_e_feature'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
def test_graph_conv_e_weight(idtype, g, norm, weight, bias):
g = g.astype(idtype).to(F.ctx())
conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias).to(F.ctx())
ext_w = F.randn((5, 2)).to(F.ctx())
nsrc = g.number_of_src_nodes()
ndst = g.number_of_dst_nodes()
h = F.randn((nsrc, 5)).to(F.ctx())
e_w = g.edata['scalar_w']
if weight:
h_out = conv(g, h, edge_weight=e_w)
else:
h_out = conv(g, h, weight=ext_w, edge_weight=e_w)
assert h_out.shape == (ndst, 2)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['has_scalar_e_feature'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
def test_graph_conv_e_weight_norm(idtype, g, norm, weight, bias):
g = g.astype(idtype).to(F.ctx())
conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias).to(F.ctx())
ext_w = F.randn((5, 2)).to(F.ctx())
nsrc = g.number_of_src_nodes()
ndst = g.number_of_dst_nodes()
h = F.randn((nsrc, 5)).to(F.ctx())
edgenorm = nn.EdgeWeightNorm(norm=norm)
norm_weight = edgenorm(g, g.edata['scalar_w'])
if weight:
h_out = conv(g, h, edge_weight=norm_weight)
else:
h_out = conv(g, h, weight=ext_w, edge_weight=norm_weight)
assert h_out.shape == (ndst, 2)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@@ -959,6 +998,8 @@ def test_hetero_conv(agg, idtype):
if __name__ == '__main__':
test_graph_conv()
test_graph_conv_e_weight()
test_graph_conv_e_weight_norm()
test_set2set()
test_glob_att_pool()
test_simple_pool()
@@ -56,6 +56,14 @@ def graph1():
g.edata['w'] = F.copy_to(F.randn((g.number_of_edges(), 3)), F.cpu())
return g
@register_case(['homo', 'has_scalar_e_feature'])
def graph1():
g = dgl.graph(([0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 6, 6, 7, 8, 9],
[4, 5, 1, 2, 4, 7, 9, 8, 6, 4, 1, 0, 1, 0, 2, 3, 5]), device=F.cpu())
g.ndata['h'] = F.copy_to(F.randn((g.number_of_nodes(), 2)), F.cpu())
g.edata['scalar_w'] = F.copy_to(F.abs(F.randn((g.number_of_edges(),))), F.cpu())
return g
@register_case(['hetero', 'has_feature'])
def heterograph0():
g = dgl.heterograph({