Unverified Commit 18bfec24 authored by Zihao Ye, committed by GitHub

[hotfix] Refactor edge softmax module (#1967)

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd
parent 3611a66e
...@@ -196,9 +196,3 @@ Sequential
.. autoclass:: dgl.nn.mxnet.utils.Sequential
:members:
:show-inheritance:
Edge Softmax
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: dgl.nn.mxnet.softmax
:members: edge_softmax
...@@ -231,8 +231,3 @@ SegmentedKNNGraph
:members:
:show-inheritance:
Edge Softmax
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: dgl.nn.pytorch.softmax
:members: edge_softmax
...@@ -106,13 +106,3 @@ GlobalAttentionPooling
.. autoclass:: dgl.nn.tensorflow.glob.GlobalAttentionPooling
:members:
:show-inheritance:
Utility Modules
----------------------------------------
Edge Softmax
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: dgl.nn.tensorflow.softmax
:members: edge_softmax
...@@ -230,6 +230,20 @@ The following is an example showing how GSDDMM works:
Like GSpMM, GSDDMM operators support both homographs and bipartite graphs.
Edge Softmax module
-------------------
We also provide a framework-agnostic edge softmax module, which is frequently used in
GNN-like structures, e.g.
`Graph Attention Network <https://arxiv.org/pdf/1710.10903.pdf>`_,
`Transformer <https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf>`_,
`Capsule <https://arxiv.org/pdf/1710.09829.pdf>`_, etc.
.. autosummary::
:toctree: ../../generated/
edge_softmax
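The following minimal sketch (reusing the toy graph from the API reference examples below)
applies ``edge_softmax`` to all-ones edge logits; by default the scores are normalized over
the incoming edges of each destination node:

.. code:: python

    import dgl
    import torch as th
    from dgl.ops import edge_softmax

    # A 3-node toy graph with 6 edges and all-ones edge logits.
    g = dgl.DGLGraph()
    g.add_nodes(3)
    g.add_edges([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2])
    edata = th.ones(6, 1)

    # Normalize over the incoming edges of each destination node (the default).
    edge_softmax(g, edata)
    # tensor([[1.0000], [0.5000], [0.3333], [0.5000], [0.3333], [0.3333]])

    # Normalize over the outgoing edges of each source node instead.
    edge_softmax(g, edata, norm_by='src')
    # tensor([[0.3333], [0.3333], [0.3333], [0.5000], [0.5000], [1.0000]])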
Relation with Message Passing APIs
----------------------------------
...@@ -264,9 +278,7 @@ would be dispatched into function calls of operators defined in ``dgl.ops``:
It is up to the user to decide whether to use message-passing APIs or GSpMM/GSDDMM operators, and both
of them have the same efficiency. Programs written in message-passing APIs look more DGL-style
but in some cases calling GSpMM/GSDDMM operators is more concise (e.g. `edge_softmax
<https://github.com/dmlc/dgl/blob/master/python/dgl/nn/pytorch/softmax.py/>`_ function
provided by dgl).
but in some cases calling GSpMM/GSDDMM operators is more concise.
Note that on PyTorch all operators defined in ``dgl.ops`` support higher-order gradients, as do the
message passing APIs, because they entirely depend on these operators.
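For example, a copy-edge/sum reduction expressed with the message-passing API and the
corresponding ``dgl.ops`` operator compute the same result (a minimal sketch; the feature
names ``'a'``, ``'m'`` and ``'h'`` are only illustrative):

.. code:: python

    import dgl
    import torch as th
    import dgl.function as fn
    from dgl import ops

    g = dgl.DGLGraph()
    g.add_nodes(3)
    g.add_edges([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2])
    g.edata['a'] = th.rand(6, 4)

    # Message-passing style: copy each edge feature and sum it onto its destination node.
    g.update_all(fn.copy_e('a', 'm'), fn.sum('m', 'h'))

    # The same computation written directly as a dgl.ops operator call.
    h = ops.copy_e_sum(g, g.edata['a'])
    assert th.allclose(g.ndata['h'], h)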
......
...@@ -1417,6 +1417,42 @@ def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
"""
pass
def edge_softmax(gidx, logits, eids, norm_by):
r"""Compute edge softmax.
For a node :math:`i`, edge softmax is an operation of computing
.. math::
a_{ij} = \frac{\exp(z_{ij})}{\sum_{j\in\mathcal{N}(i)}\exp(z_{ij})}
where :math:`z_{ij}` is a signal of edge :math:`j\rightarrow i`, also
called logits in the context of softmax. :math:`\mathcal{N}(i)` is
the set of nodes that have an edge to :math:`i`.
By default edge softmax is normalized by destination nodes (i.e. :math:`ij`
are incoming edges of :math:`i` in the formula above). We also support edge
softmax normalized by source nodes (i.e. :math:`ij` are outgoing edges of
:math:`i` in the formula). The former case corresponds to softmax in GAT and
Transformer, and the latter corresponds to softmax in the Capsule network.
Parameters
----------
gidx : HeteroGraphIndex
The graph to perform edge softmax on.
logits : torch.Tensor
The input edge feature
eids : torch.Tensor or ALL, optional
Edges on which to apply edge softmax. If ALL, apply edge
softmax on all edges in the graph. Default: ALL.
norm_by : str, could be `src` or `dst`
Normalized by source nodes or destination nodes. Default: `dst`.
Returns
-------
Tensor
Softmax value
"""
###############################################################################
# Other interfaces
......
...@@ -2,9 +2,11 @@ import mxnet as mx
import numpy as np
from mxnet import nd
from ...sparse import _gspmm, _gsddmm
from ...base import dgl_warning
from ...base import dgl_warning, is_all, ALL
from .tensor import asnumpy, copy_to, zerocopy_from_numpy, context, to_backend_ctx
__all__ = ['gspmm', 'gsddmm', 'edge_softmax']
def _scatter_nd(index, src, n_rows):
"""Similar to PyTorch's scatter nd on first dimension."""
...@@ -104,9 +106,6 @@ def _addsub(op, x):
def _expand(x, shape):
padding_zeros = len(shape) + 1 - x.ndim
if padding_zeros > 0:
x = x.reshape((x.shape[0],) + (1,) * padding_zeros + x.shape[1:])
return x.broadcast_to((x.shape[0], *shape))
...@@ -264,3 +263,62 @@ def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
if rhs_data is None:
rhs_data = nd.zeros((1,), ctx=ctx)
return func(lhs_data, rhs_data)
class EdgeSoftmax(mx.autograd.Function):
def __init__(self, gidx, eids, norm_by):
super(EdgeSoftmax, self).__init__()
if not is_all(eids):
gidx = gidx.edge_subgraph(eids.astype(gidx.dtype), True)
if norm_by == 'src':
gidx = gidx.reverse()
self.gidx = gidx
def forward(self, score):
"""Forward function.
Pseudo-code:
.. code:: python
score = dgl.EData(g, score)
score_max = score.dst_max() # of type dgl.NData
score = score - score_max # edge_sub_dst, ret dgl.EData
score_sum = score.dst_sum() # of type dgl.NData
out = score / score_sum # edge_div_dst, ret dgl.EData
return out.data
"""
gidx = self.gidx
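# Subtract the per-destination max from the logits before exponentiating
# (the log-sum-exp trick) so the softmax stays numerically stable.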
score_max = _gspmm(gidx, 'copy_rhs', 'max', None, score)[0]
score = mx.nd.exp(_gsddmm(gidx, 'sub', score, score_max, 'e', 'v'))
score_sum = _gspmm(gidx, 'copy_rhs', 'sum', None, score)[0]
out = _gsddmm(gidx, 'div', score, score_sum, 'e', 'v')
self.save_for_backward(out)
return out
def backward(self, grad_out):
"""Backward function.
Pseudo-code:
.. code:: python
g, out = ctx.backward_cache
grad_out = dgl.EData(g, grad_out)
out = dgl.EData(g, out)
sds = out * grad_out # type dgl.EData
sds_sum = sds.dst_sum() # type dgl.NData
grad_score = sds - out * sds_sum # multiple expressions
"""
out, = self.saved_tensors
gidx = self.gidx
sds = out * grad_out
accum = gspmm(gidx, 'copy_rhs', 'sum', None, sds)
grad_score = sds - gsddmm(gidx, 'mul', out, accum, 'e', 'v')
self.saved_tensors = None  # release the tensor saved for backward
return grad_score
def edge_softmax(gidx, logits, eids=ALL, norm_by='dst'):
softmax_op = EdgeSoftmax(gidx, eids, norm_by)
return softmax_op(logits)
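For reference, the backward passes here and in the PyTorch and TensorFlow backends below all
apply the same softmax gradient identity. Writing :math:`a_{ij}` for the forward output and
:math:`g_{ij} = \partial L / \partial a_{ij}` for the incoming gradient,

.. math::

    \frac{\partial a_{ij}}{\partial z_{ik}} = a_{ij}\left(\delta_{jk} - a_{ik}\right),
    \qquad
    \frac{\partial L}{\partial z_{ik}}
    = \sum_{j\in\mathcal{N}(i)} g_{ij}\, a_{ij}\left(\delta_{jk} - a_{ik}\right)
    = g_{ik} a_{ik} - a_{ik} \sum_{j\in\mathcal{N}(i)} g_{ij} a_{ij},

which is exactly ``sds - gsddmm(gidx, 'mul', out, accum, 'e', 'v')`` with ``sds = out * grad_out``
and ``accum`` the per-destination sum of ``sds``.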
import torch as th
from ...base import is_all, ALL
from ...sparse import _gspmm, _gsddmm
__all__ = ['gspmm', 'gsddmm']
__all__ = ['gspmm', 'gsddmm', 'edge_softmax']
def _reduce_grad(grad, shape):
...@@ -54,9 +55,6 @@ def _addsub(op, x):
def _expand(x, shape):
padding_zeros = len(shape) + 1 - x.ndim
if padding_zeros > 0:
x = x.view((x.shape[0],) + (1,) * padding_zeros + x.shape[1:])
return x.expand(-1, *shape)
...@@ -179,9 +177,67 @@ class GSDDMM(th.autograd.Function):
return None, None, dX, dY, None, None
class EdgeSoftmax(th.autograd.Function):
@staticmethod
def forward(ctx, gidx, score, eids, norm_by):
"""Forward function.
Pseudo-code:
.. code:: python
score = dgl.EData(g, score)
score_max = score.dst_max() # of type dgl.NData
score = score - score_max # edge_sub_dst, ret dgl.EData
score_sum = score.dst_sum() # of type dgl.NData
out = score / score_sum # edge_div_dst, ret dgl.EData
return out.data
"""
# remember to save the graph to backward cache before making it
# a local variable
if not is_all(eids):
gidx = gidx.edge_subgraph(eids.type(gidx.dtype), True)
if norm_by == 'src':
gidx = gidx.reverse()
score_max = _gspmm(gidx, 'copy_rhs', 'max', None, score)[0]
score = th.exp(_gsddmm(gidx, 'sub', score, score_max, 'e', 'v'))
score_sum = _gspmm(gidx, 'copy_rhs', 'sum', None, score)[0]
out = _gsddmm(gidx, 'div', score, score_sum, 'e', 'v')
ctx.backward_cache = gidx
ctx.save_for_backward(out)
return out
@staticmethod
def backward(ctx, grad_out):
"""Backward function.
Pseudo-code:
.. code:: python
g, out = ctx.backward_cache
grad_out = dgl.EData(g, grad_out)
out = dgl.EData(g, out)
sds = out * grad_out # type dgl.EData
sds_sum = sds.dst_sum() # type dgl.NData
grad_score = sds - out * sds_sum # multiple expressions
return grad_score.data
"""
gidx = ctx.backward_cache
out, = ctx.saved_tensors
sds = out * grad_out
accum = gspmm(gidx, 'copy_rhs', 'sum', None, sds)
grad_score = sds - gsddmm(gidx, 'mul', out, accum, 'e', 'v')
return None, grad_score, None, None
def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
return GSpMM.apply(gidx, op, reduce_op, lhs_data, rhs_data)
def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
return GSDDMM.apply(gidx, op, lhs_data, rhs_data, lhs_target, rhs_target)
def edge_softmax(gidx, logits, eids=ALL, norm_by='dst'):
return EdgeSoftmax.apply(gidx, logits, eids, norm_by)
import tensorflow as tf
import numpy as np
from .tensor import tensor, copy_to, context
from ...base import is_all, ALL
from ...sparse import _gspmm, _gsddmm
__all__ = ['gspmm', 'gsddmm', 'edge_softmax']
def _scatter_nd(index, src, n_rows):
assert index.shape == src.shape
...@@ -92,9 +95,6 @@ def _addsub(op, x):
def _expand(x, shape):
padding_zeros = len(shape) + 1 - x.ndim
if padding_zeros > 0:
x = tf.reshape(x, (x.shape[0],) + (1,) * padding_zeros + x.shape[1:])
return tf.broadcast_to(x, (x.shape[0], *shape))
...@@ -221,3 +221,30 @@ def gsddmm(gidx, op, X, Y, lhs_target='u', rhs_target='v'):
if Y is None:
Y = tf.zeros(())
return _lambda(X, Y)
def edge_softmax_real(gidx, score, eids=ALL, norm_by='dst'):
if not is_all(eids):
gidx = gidx.edge_subgraph(tf.cast(eids, gidx.dtype), True)
if norm_by == 'src':
gidx = gidx.reverse()
score_max = _gspmm(gidx, 'copy_rhs', 'max', None, score)[0]
score = tf.math.exp(_gsddmm(gidx, 'sub', score, score_max, 'e', 'v'))
score_sum = _gspmm(gidx, 'copy_rhs', 'sum', None, score)[0]
out = _gsddmm(gidx, 'div', score, score_sum, 'e', 'v')
def edge_softmax_backward(grad_out):
sds = out * grad_out
accum = gspmm(gidx, 'copy_rhs', 'sum', None, sds)
grad_score = sds - gsddmm(gidx, 'mul', out, accum, 'e', 'v')
return grad_score
return out, edge_softmax_backward
def edge_softmax(gidx, logits, eids=ALL, norm_by='dst'):
@tf.custom_gradient
def _lambda(logits):
return edge_softmax_real(gidx, logits, eids, norm_by)
return _lambda(logits)
...@@ -4,7 +4,7 @@ import mxnet as mx
from mxnet.gluon import nn
from .... import function as fn
from ..softmax import edge_softmax
from ....ops import edge_softmax
from ..utils import normalize
from ....utils import expand_as_pair
......
...@@ -6,7 +6,7 @@ from mxnet.gluon import nn
from mxnet.gluon.contrib.nn import Identity
from .... import function as fn
from ..softmax import edge_softmax
from ....ops import edge_softmax
from ....utils import expand_as_pair
#pylint: enable=W0235
......
"""Gluon layer for graph related softmax.""" """Gluon layer for graph related softmax."""
# pylint: disable= no-member, arguments-differ, access-member-before-definition, unpacking-non-sequence # pylint: disable= unused-import
import mxnet as mx from ...ops import edge_softmax
from ... import ops as F
from ...base import ALL, is_all
__all__ = ['edge_softmax']
class EdgeSoftmax(mx.autograd.Function):
r"""Apply softmax over signals of incoming edges.
For a node :math:`i`, edgesoftmax is an operation of computing
.. math::
a_{ij} = \frac{\exp(z_{ij})}{\sum_{j\in\mathcal{N}(i)}\exp(z_{ij})}
where :math:`z_{ij}` is a signal of edge :math:`j\rightarrow i`, also
called logits in the context of softmax. :math:`\mathcal{N}(i)` is
the set of nodes that have an edge to :math:`i`.
An example of using edgesoftmax is in
`Graph Attention Network <https://arxiv.org/pdf/1710.10903.pdf>`__ where
the attention weights are computed with such an edgesoftmax operation.
"""
def __init__(self, g, eids):
super(EdgeSoftmax, self).__init__()
if not is_all(eids):
g = g.edge_subgraph(eids.astype(g.idtype), preserve_nodes=True)
self.g = g
def forward(self, score):
"""Forward function.
Pseudo-code:
.. code:: python
score = dgl.EData(g, score)
score_max = score.dst_max() # of type dgl.NData
score = score - score_max # edge_sub_dst, ret dgl.EData
score_sum = score.dst_sum() # of type dgl.NData
out = score / score_sum # edge_div_dst, ret dgl.EData
return out.data
"""
g = self.g
score_max = F.copy_e_max(g, score)
score = mx.nd.exp(F.e_sub_v(g, score, score_max))
score_sum = F.copy_e_sum(g, score)
out = F.e_div_v(g, score, score_sum)
self.save_for_backward(out)
return out
def backward(self, grad_out):
"""Backward function.
Pseudo-code:
.. code:: python
g, out = ctx.backward_cache
grad_out = dgl.EData(g, grad_out)
out = dgl.EData(g, out)
sds = out * grad_out # type dgl.EData
sds_sum = sds.dst_sum() # type dgl.NData
grad_score = sds - sds * sds_sum # multiple expressions
"""
out, = self.saved_tensors
g = self.g
sds = out * grad_out
accum = F.copy_e_sum(g, sds)
grad_score = sds - F.e_mul_v(g, out, accum)
self.save_tensors = None
return grad_score
def edge_softmax(graph, logits, eids=ALL):
r"""Compute edge softmax.
For a node :math:`i`, edge softmax is an operation of computing
.. math::
a_{ij} = \frac{\exp(z_{ij})}{\sum_{j\in\mathcal{N}(i)}\exp(z_{ij})}
where :math:`z_{ij}` is a signal of edge :math:`j\rightarrow i`, also
called logits in the context of softmax. :math:`\mathcal{N}(i)` is
the set of nodes that have an edge to :math:`i`.
An example of using edge softmax is in
`Graph Attention Network <https://arxiv.org/pdf/1710.10903.pdf>`__ where
the attention weights are computed with such an edge softmax operation.
Parameters
----------
graph : DGLGraph
The graph to perform edge softmax
logits : mxnet.NDArray
The input edge feature
eids : mxnet.NDArray or ALL, optional
Edges on which to apply edge softmax. If ALL, apply edge softmax
on all edges in the graph. Default: ALL.
Returns
-------
Tensor
Softmax value
Notes
-----
* Input shape: :math:`(E, *, 1)` where * means any number of
additional dimensions, :math:`E` equals the length of eids.
If eids is ALL, :math:`E` equals number of edges in the graph.
* Return shape: :math:`(E, *, 1)`
Examples
--------
>>> from dgl.nn.mxnet.softmax import edge_softmax
>>> import dgl
>>> from mxnet import nd
Create a :code:`DGLGraph` object and initialize its edge features.
>>> g = dgl.DGLGraph()
>>> g.add_nodes(3)
>>> g.add_edges([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2])
>>> edata = nd.ones((6, 1))
>>> edata
[[1.]
[1.]
[1.]
[1.]
[1.]
[1.]]
<NDArray 6x1 @cpu(0)>
Apply edge softmax on g:
>>> edge_softmax(g, edata)
[[1. ]
[0.5 ]
[0.33333334]
[0.5 ]
[0.33333334]
[0.33333334]]
<NDArray 6x1 @cpu(0)>
Apply edge softmax on first 4 edges of g:
>>> edge_softmax(g, edata, nd.array([0,1,2,3], dtype='int64'))
[[1. ]
[0.5]
[1. ]
[0.5]]
<NDArray 4x1 @cpu(0)>
"""
softmax_op = EdgeSoftmax(graph, eids)
return softmax_op(logits)
...@@ -5,7 +5,7 @@ from torch import nn
from torch.nn import functional as F
from .... import function as fn
from ..softmax import edge_softmax
from ....ops import edge_softmax
from ....utils import expand_as_pair
......
...@@ -3,7 +3,7 @@
from torch import nn
from .... import function as fn
from ..softmax import edge_softmax
from ....ops import edge_softmax
from ....utils import expand_as_pair
......
...@@ -4,7 +4,7 @@ import torch as th
from torch import nn
from .... import function as fn
from ..softmax import edge_softmax
from ....ops import edge_softmax
from ..utils import Identity
from ....utils import expand_as_pair
......
"""Torch modules for graph related softmax.""" """Torch modules for graph related softmax."""
# pylint: disable= no-member, arguments-differ # pylint: disable= unused-import
import torch as th from ...ops import edge_softmax
from ...base import ALL, is_all
from ... import ops as F
__all__ = ['edge_softmax']
class EdgeSoftmax(th.autograd.Function):
r"""Apply softmax over signals of incoming edges.
For a node :math:`i`, edgesoftmax is an operation of computing
.. math::
a_{ij} = \frac{\exp(z_{ij})}{\sum_{j\in\mathcal{N}(i)}\exp(z_{ij})}
where :math:`z_{ij}` is a signal of edge :math:`j\rightarrow i`, also
called logits in the context of softmax. :math:`\mathcal{N}(i)` is
the set of nodes that have an edge to :math:`i`.
An example of using edgesoftmax is in
`Graph Attention Network <https://arxiv.org/pdf/1710.10903.pdf>`__ where
the attention weights are computed with such an edgesoftmax operation.
"""
@staticmethod
def forward(ctx, g, score, eids):
"""Forward function.
Pseudo-code:
.. code:: python
score = dgl.EData(g, score)
score_max = score.dst_max() # of type dgl.NData
score = score - score_max # edge_sub_dst, ret dgl.EData
score_sum = score.dst_sum() # of type dgl.NData
out = score / score_sum # edge_div_dst, ret dgl.EData
return out.data
"""
# remember to save the graph to backward cache before making it
# a local variable
if not is_all(eids):
g = g.edge_subgraph(eids.type(g.idtype), preserve_nodes=True)
score_max = F.copy_e_max(g, score)
score = th.exp(F.e_sub_v(g, score, score_max))
score_sum = F.copy_e_sum(g, score)
out = F.e_div_v(g, score, score_sum)
ctx.backward_cache = g
ctx.save_for_backward(out)
return out
@staticmethod
def backward(ctx, grad_out):
"""Backward function.
Pseudo-code:
.. code:: python
g, out = ctx.backward_cache
grad_out = dgl.EData(g, grad_out)
out = dgl.EData(g, out)
sds = out * grad_out # type dgl.EData
sds_sum = sds.dst_sum() # type dgl.NData
grad_score = sds - out * sds_sum # multiple expressions
return grad_score.data
"""
g = ctx.backward_cache
out, = ctx.saved_tensors
sds = out * grad_out
accum = F.copy_e_sum(g, sds)
grad_score = sds - F.e_mul_v(g, out, accum)
return None, grad_score, None
def edge_softmax(graph, logits, eids=ALL):
r"""Compute edge softmax.
For a node :math:`i`, edge softmax is an operation of computing
.. math::
a_{ij} = \frac{\exp(z_{ij})}{\sum_{j\in\mathcal{N}(i)}\exp(z_{ij})}
where :math:`z_{ij}` is a signal of edge :math:`j\rightarrow i`, also
called logits in the context of softmax. :math:`\mathcal{N}(i)` is
the set of nodes that have an edge to :math:`i`.
An example of using edge softmax is in
`Graph Attention Network <https://arxiv.org/pdf/1710.10903.pdf>`__ where
the attention weights are computed with such an edge softmax operation.
Parameters
----------
graph : DGLGraph
The graph to perform edge softmax
logits : torch.Tensor
The input edge feature
eids : torch.Tensor or ALL, optional
Edges on which to apply edge softmax. If ALL, apply edge
softmax on all edges in the graph. Default: ALL.
Returns
-------
Tensor
Softmax value
Notes
-----
* Input shape: :math:`(E, *, 1)` where * means any number of
additional dimensions, :math:`E` equals the length of eids.
If eids is ALL, :math:`E` equals number of edges in the graph.
* Return shape: :math:`(E, *, 1)`
Examples
--------
>>> from dgl.nn.pytorch.softmax import edge_softmax
>>> import dgl
>>> import torch as th
Create a :code:`DGLGraph` object and initialize its edge features.
>>> g = dgl.DGLGraph()
>>> g.add_nodes(3)
>>> g.add_edges([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2])
>>> edata = th.ones(6, 1).float()
>>> edata
tensor([[1.],
[1.],
[1.],
[1.],
[1.],
[1.]])
Apply edge softmax on g:
>>> edge_softmax(g, edata)
tensor([[1.0000],
[0.5000],
[0.3333],
[0.5000],
[0.3333],
[0.3333]])
Apply edge softmax on first 4 edges of g:
>>> edge_softmax(g, edata[:4], th.Tensor([0,1,2,3]))
tensor([[1.0000],
[0.5000],
[1.0000],
[0.5000]])
"""
return EdgeSoftmax.apply(graph, logits, eids)
...@@ -5,7 +5,7 @@ from tensorflow.keras import layers
import numpy as np
from .... import function as fn
from ..softmax import edge_softmax
from ....ops import edge_softmax
from ..utils import Identity
# pylint: enable=W0235
......
"""tf modules for graph related softmax.""" """tf modules for graph related softmax."""
# pylint: disable= no-member, arguments-differ # pylint: disable= unused-import
import tensorflow as tf from ...ops import edge_softmax
from ...sparse import _gspmm, _gsddmm
from ...base import ALL, is_all
__all__ = ['edge_softmax']
def edge_softmax_real(graph, score, eids=ALL):
"""Edge Softmax function"""
if not is_all(eids):
graph = graph.edge_subgraph(tf.cast(eids, graph.idtype), preserve_nodes=True)
gidx = graph._graph
score_max = _gspmm(gidx, 'copy_rhs', 'max', None, score)[0]
score = tf.math.exp(_gsddmm(gidx, 'sub', score, score_max, 'e', 'v'))
score_sum = _gspmm(gidx, 'copy_rhs', 'sum', None, score)[0]
out = _gsddmm(gidx, 'div', score, score_sum, 'e', 'v')
def edge_softmax_backward(grad_out):
sds = out * grad_out
accum = _gspmm(gidx, 'copy_rhs', 'sum', None, sds)[0]
grad_score = sds - _gsddmm(gidx, 'mul', out, accum, 'e', 'v')
return grad_score
return out, edge_softmax_backward
def edge_softmax(graph, logits, eids=ALL):
"""Closure for tf.custom_gradient"""
@tf.custom_gradient
def _lambda(logits):
return edge_softmax_real(graph, logits, eids=eids)
return _lambda(logits)
"""dgl operator module.""" """dgl operator module."""
from .spmm import * from .spmm import *
from .sddmm import * from .sddmm import *
from .edge_softmax import *
"""dgl edge_softmax operator module."""
from ..backend import edge_softmax as edge_softmax_internal
from ..base import ALL
__all__ = ['edge_softmax']
def edge_softmax(graph, logits, eids=ALL, norm_by='dst'):
r"""Compute edge softmax.
For a node :math:`i`, edge softmax is an operation of computing
.. math::
a_{ij} = \frac{\exp(z_{ij})}{\sum_{j\in\mathcal{N}(i)}\exp(z_{ij})}
where :math:`z_{ij}` is a signal of edge :math:`j\rightarrow i`, also
called logits in the context of softmax. :math:`\mathcal{N}(i)` is
the set of nodes that have an edge to :math:`i`.
By default edge softmax is normalized by destination nodes (i.e. :math:`ij`
are incoming edges of :math:`i` in the formula above). We also support edge
softmax normalized by source nodes (i.e. :math:`ij` are outgoing edges of
:math:`i` in the formula). The former case corresponds to softmax in GAT and
Transformer, and the latter corresponds to softmax in the Capsule network.
Parameters
----------
graph : DGLGraph
The graph to perform edge softmax on.
logits : torch.Tensor
The input edge feature
eids : torch.Tensor or ALL, optional
Edges on which to apply edge softmax. If ALL, apply edge
softmax on all edges in the graph. Default: ALL.
norm_by : str, could be `src` or `dst`
Normalized by source nodes or destination nodes. Default: `dst`.
Returns
-------
Tensor
Softmax value
Notes
-----
* Input shape: :math:`(E, *, 1)` where * means any number of
additional dimensions, :math:`E` equals the length of eids.
* Return shape: :math:`(E, *, 1)`
Examples
--------
>>> from dgl.ops import edge_softmax
>>> import dgl
>>> import torch as th
Create a :code:`DGLGraph` object and initialize its edge features.
>>> g = dgl.DGLGraph()
>>> g.add_nodes(3)
>>> g.add_edges([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2])
>>> edata = th.ones(6, 1).float()
>>> edata
tensor([[1.],
[1.],
[1.],
[1.],
[1.],
[1.]])
Apply edge softmax on g:
>>> edge_softmax(g, edata)
tensor([[1.0000],
[0.5000],
[0.3333],
[0.5000],
[0.3333],
[0.3333]])
Apply edge softmax on g normalized by source nodes:
>>> edge_softmax(g, edata, norm_by='src')
tensor([[0.3333],
[0.3333],
[0.3333],
[0.5000],
[0.5000],
[1.0000]])
Apply edge softmax on first 4 edges of g:
>>> edge_softmax(g, edata[:4], th.Tensor([0,1,2,3]))
tensor([[1.0000],
[0.5000],
[1.0000],
[0.5000]])
"""
return edge_softmax_internal(graph._graph, logits,
eids=eids, norm_by=norm_by)
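As a usage illustration, a single GAT-style attention head can combine ``edge_softmax`` with the
message-passing API roughly as follows (a hedged sketch, not code from this commit; ``gat_head``,
``W``, ``attn_l``, ``attn_r`` and the feature field names are illustrative):

import dgl.function as fn
from torch.nn.functional import leaky_relu
from dgl.ops import edge_softmax

def gat_head(g, feat, W, attn_l, attn_r):
    """One attention head. W: (in_dim, out_dim); attn_l, attn_r: (out_dim,)."""
    z = feat @ W                                   # project node features
    g.ndata['z'] = z
    g.ndata['el'] = (z * attn_l).sum(dim=-1, keepdim=True)
    g.ndata['er'] = (z * attn_r).sum(dim=-1, keepdim=True)
    g.apply_edges(fn.u_add_v('el', 'er', 'e'))     # e_uv = a_l^T z_u + a_r^T z_v
    a = edge_softmax(g, leaky_relu(g.edata['e']))  # normalize over incoming edges
    g.edata['a'] = a
    g.update_all(fn.u_mul_e('z', 'a', 'm'), fn.sum('m', 'h'))
    return g.ndata['h']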
from dgl.ops import gspmm, gsddmm
from dgl.ops import gspmm, gsddmm, edge_softmax
from test_utils.graph_cases import get_cases
from utils import parametrize_dtype
import dgl
import random
...@@ -6,7 +7,6 @@ import pytest
import networkx as nx
import backend as F
import numpy as np
from utils import parametrize_dtype
random.seed(42)
np.random.seed(42)
...@@ -90,6 +90,10 @@ sddmm_shapes = [
((1,), (1,))
]
edge_softmax_shapes = [
(1,), (1, 3), (3, 4, 5)
]
@pytest.mark.parametrize('g', graphs)
@pytest.mark.parametrize('shp', spmm_shapes)
@pytest.mark.parametrize('msg', ['add', 'sub', 'mul', 'div', 'copy_lhs', 'copy_rhs'])
...@@ -222,5 +226,36 @@ def test_sddmm(g, shp, lhs_target, rhs_target, msg, idtype):
rhs_frame.pop('y')
if 'm' in g.edata: g.edata.pop('m')
@pytest.mark.parametrize('g', get_cases(['clique']))
@pytest.mark.parametrize('norm_by', ['src', 'dst'])
@pytest.mark.parametrize('shp', edge_softmax_shapes)
@parametrize_dtype
def test_edge_softmax(g, norm_by, shp, idtype):
g = g.astype(idtype).to(F.ctx())
edata = F.tensor(np.random.rand(g.number_of_edges(), *shp))
e1 = F.attach_grad(F.clone(edata))
with F.record_grad():
score1 = edge_softmax(g, e1, norm_by=norm_by)
F.backward(F.reduce_sum(score1))
grad_edata = F.grad(e1)
with F.record_grad():
e2 = F.attach_grad(F.clone(edata))
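# The 'clique' test graphs provide exactly one edge for every (src, dst) pair,
# so the edge scores can be reshaped into a dense (num_src, num_dst, ...) tensor
# and edge softmax reduces to an ordinary softmax along one of the two axes.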
e2_2d = F.reshape(
e2, (g.number_of_src_nodes(), g.number_of_dst_nodes(), *e2.shape[1:]))
if norm_by == 'src':
score2 = F.softmax(e2_2d, 1)
score2 = F.reshape(score2, (-1, *e2.shape[1:]))
if norm_by == 'dst':
score2 = F.softmax(e2_2d, 0)
score2 = F.reshape(score2, (-1, *e2.shape[1:]))
assert F.allclose(score1, score2)
print('forward passed')
F.backward(F.reduce_sum(score2))
assert F.allclose(F.grad(e2), grad_edata)
print('backward passed')
if __name__ == '__main__':
test_spmm(F.int32, graphs[0], spmm_shapes[5], 'copy_lhs', 'sum')