OpenDAS / dgl
Commit 742d79a7 (unverified)
authored Aug 06, 2019 by Zihao Ye, committed by GitHub on Aug 06, 2019

upd (#741)

parent 5d3f470b

Showing 4 changed files with 124 additions and 40 deletions (+124 -40)
docs/source/api/python/nn.mxnet.rst      +2  -4
docs/source/api/python/nn.pytorch.rst    +1  -5
python/dgl/nn/mxnet/softmax.py           +61 -15
python/dgl/nn/pytorch/softmax.py         +60 -16
docs/source/api/python/nn.mxnet.rst

@@ -17,12 +17,10 @@ dgl.nn.mxnet.glob
.. automodule:: dgl.nn.mxnet.glob
    :members:
    :show-inheritance:

dgl.nn.mxnet.softmax
--------------------

.. automodule:: dgl.nn.mxnet.softmax
    :members: edge_softmax

.. autoclass:: dgl.nn.mxnet.softmax.EdgeSoftmax
    :members: forward
    :show-inheritance:
docs/source/api/python/nn.pytorch.rst

@@ -14,7 +14,6 @@ dgl.nn.pytorch.conv
dgl.nn.pytorch.glob
-------------------

.. automodule:: dgl.nn.pytorch.glob

.. autoclass:: dgl.nn.pytorch.glob.SumPooling

@@ -53,7 +52,4 @@ dgl.nn.pytorch.softmax
----------------------

.. automodule:: dgl.nn.pytorch.softmax
    :members: edge_softmax

.. autoclass:: dgl.nn.pytorch.softmax.EdgeSoftmax
    :members: forward
    :show-inheritance:
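For context on the Sphinx autodoc directives touched in these two files: ``.. automodule::`` with ``:members: edge_softmax`` documents the module and only the listed member, while ``.. autoclass::`` with ``:members: forward`` and ``:show-inheritance:`` documents the class, its ``forward`` method, and its base classes.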
python/dgl/nn/mxnet/softmax.py

@@ -32,6 +32,9 @@ class EdgeSoftmax(mx.autograd.Function):
    """Forward function.

    Pseudo-code:

    .. code:: python

        score = dgl.EData(g, score)
        score_max = score.dst_max()  # of type dgl.NData
        score = score - score_max    # edge_sub_dst, ret dgl.EData
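The pseudo-code above shows only the start of the forward pass; the remaining steps (exponentiate, sum per destination, divide) are elided in this hunk. A minimal NumPy sketch of the same numerically stable computation, assuming for illustration that edges are described by a `dst` array giving each edge's destination node (this layout is not DGL's internal representation):

```python
import numpy as np

def edge_softmax_np(dst, score):
    """Numerically stable edge softmax: normalize exp(score) over the edges
    that share a destination node."""
    num_nodes = dst.max() + 1
    # Per-destination max, gathered back onto edges (the "dst_max" step above).
    score_max = np.full((num_nodes, 1), -np.inf)
    np.maximum.at(score_max, dst, score)
    score = np.exp(score - score_max[dst])   # shift for stability, then exponentiate
    # Per-destination sum of the exponentiated scores.
    out_sum = np.zeros((num_nodes, 1))
    np.add.at(out_sum, dst, score)
    return score / out_sum[dst]              # normalize within each destination group

# Reproduces the all-ones example from the docstring further down:
# edges (0,0), (0,1), (0,2), (1,1), (1,2), (2,2), all with unit logits.
dst = np.array([0, 1, 2, 1, 2, 2])
print(edge_softmax_np(dst, np.ones((6, 1))))
# column vector, approximately [1, 0.5, 0.333, 0.5, 0.333, 0.333]
```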
@@ -54,6 +57,9 @@ class EdgeSoftmax(mx.autograd.Function):
    """Backward function.

    Pseudo-code:

    .. code:: python

        g, out = ctx.backward_cache
        grad_out = dgl.EData(g, grad_out)
        out = dgl.EData(g, out)
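The rest of the backward pseudo-code is elided in this hunk. Mathematically it reduces to the standard softmax Jacobian applied independently within each destination group; a NumPy sketch under the same illustrative `dst` edge layout as above:

```python
import numpy as np

def edge_softmax_grad_np(dst, out, grad_out):
    """Gradient of edge softmax w.r.t. the input scores:
    grad = out * grad_out - out * (per-destination sum of out * grad_out)."""
    num_nodes = dst.max() + 1
    sds = out * grad_out                  # elementwise product on each edge
    accum = np.zeros((num_nodes, 1))
    np.add.at(accum, dst, sds)            # sum of out * grad_out per destination node
    return sds - out * accum[dst]
```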
@@ -75,6 +81,19 @@ class EdgeSoftmax(mx.autograd.Function):
def edge_softmax(graph, logits):
    r"""Compute edge softmax.

    For a node :math:`i`, edge softmax is an operation of computing

    .. math::
      a_{ij} = \frac{\exp(z_{ij})}{\sum_{j\in\mathcal{N}(i)}\exp(z_{ij})}

    where :math:`z_{ij}` is a signal of edge :math:`j\rightarrow i`, also
    called logits in the context of softmax. :math:`\mathcal{N}(i)` is
    the set of nodes that have an edge to :math:`i`.

    An example of using edge softmax is in
    `Graph Attention Network <https://arxiv.org/pdf/1710.10903.pdf>`__ where
    the attention weights are computed with such an edge softmax operation.

    Parameters
    ----------
    graph : DGLGraph
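As a quick check of the formula, consider the all-ones logits used in the example further down: every incoming edge of node :math:`i` then receives the same weight,

.. math::
   a_{ij} = \frac{\exp(1)}{|\mathcal{N}(i)|\,\exp(1)} = \frac{1}{|\mathcal{N}(i)|},

so a node with three incoming edges assigns roughly 0.3333 to each of them, which is exactly what the example output shows.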
@@ -95,8 +114,35 @@ def edge_softmax(graph, logits):
    Examples
    --------
    >>> import dgl.function as fn
    >>> from dgl.nn.mxnet.softmax import edge_softmax
    >>> attention = EdgeSoftmax(logits, graph)
    >>> import dgl
    >>> from mxnet import nd

    Create a :code:`DGLGraph` object and initialize its edge features.

    >>> g = dgl.DGLGraph()
    >>> g.add_nodes(3)
    >>> g.add_edges([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2])
    >>> edata = nd.ones((6, 1))
    >>> edata
    [[1.]
     [1.]
     [1.]
     [1.]
     [1.]
     [1.]]
    <NDArray 6x1 @cpu(0)>

    Apply edge softmax on g:

    >>> edge_softmax(g, edata)
    [[1.        ]
     [0.5       ]
     [0.33333334]
     [0.5       ]
     [0.33333334]
     [0.33333334]]
    <NDArray 6x1 @cpu(0)>
    """
    softmax_op = EdgeSoftmax(graph)
    return softmax_op(logits)
python/dgl/nn/pytorch/softmax.py

@@ -29,6 +29,9 @@ class EdgeSoftmax(th.autograd.Function):
    """Forward function.

    Pseudo-code:

    .. code:: python

        score = dgl.EData(g, score)
        score_max = score.dst_max()  # of type dgl.NData
        score = score - score_max    # edge_sub_dst, ret dgl.EData
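This is the same forward pseudo-code as in the MXNet backend. A plain PyTorch reference (omitting the max shift for brevity, and using the same illustrative `dst` edge layout as the NumPy sketches above) that reproduces the docstring example further down:

```python
import torch as th

def edge_softmax_ref(dst, logits):
    """Reference edge softmax: normalize exp(logits) over edges sharing a destination.
    The real forward pass additionally subtracts a per-destination max for stability."""
    num_nodes = int(dst.max()) + 1
    exp = logits.exp()
    denom = th.zeros(num_nodes, logits.shape[1])
    denom.index_add_(0, dst, exp)        # per-destination sum of exp(logits)
    return exp / denom[dst]              # broadcast each destination's sum back to its edges

dst = th.tensor([0, 1, 2, 1, 2, 2])     # destinations of the six edges in the example
print(edge_softmax_ref(dst, th.ones(6, 1)))
# tensor([[1.0000], [0.5000], [0.3333], [0.5000], [0.3333], [0.3333]])
```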
@@ -55,6 +58,9 @@ class EdgeSoftmax(th.autograd.Function):
    """Backward function.

    Pseudo-code:

    .. code:: python

        g, out = ctx.backward_cache
        grad_out = dgl.EData(g, grad_out)
        out = dgl.EData(g, out)
@@ -79,6 +85,19 @@ class EdgeSoftmax(th.autograd.Function):
def edge_softmax(graph, logits):
    r"""Compute edge softmax.

    For a node :math:`i`, edge softmax is an operation of computing

    .. math::
      a_{ij} = \frac{\exp(z_{ij})}{\sum_{j\in\mathcal{N}(i)}\exp(z_{ij})}

    where :math:`z_{ij}` is a signal of edge :math:`j\rightarrow i`, also
    called logits in the context of softmax. :math:`\mathcal{N}(i)` is
    the set of nodes that have an edge to :math:`i`.

    An example of using edge softmax is in
    `Graph Attention Network <https://arxiv.org/pdf/1710.10903.pdf>`__ where
    the attention weights are computed with such an edge softmax operation.

    Parameters
    ----------
    graph : DGLGraph
@@ -99,7 +118,32 @@ def edge_softmax(graph, logits):
    Examples
    --------
    >>> import dgl.function as fn
    >>> from dgl.nn.pytorch.softmax import edge_softmax
    >>> attention = EdgeSoftmax(logits, graph)
    >>> import dgl
    >>> import torch as th

    Create a :code:`DGLGraph` object and initialize its edge features.

    >>> g = dgl.DGLGraph()
    >>> g.add_nodes(3)
    >>> g.add_edges([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2])
    >>> edata = th.ones(6, 1).float()
    >>> edata
    tensor([[1.],
            [1.],
            [1.],
            [1.],
            [1.],
            [1.]])

    Apply edge softmax on g:

    >>> edge_softmax(g, edata)
    tensor([[1.0000],
            [0.5000],
            [0.3333],
            [0.5000],
            [0.3333],
            [0.3333]])
    """
    return EdgeSoftmax.apply(graph, logits)
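One small API difference between the two backends is visible in this diff: the MXNet `edge_softmax` instantiates the custom Function and then calls it (`EdgeSoftmax(graph)` followed by `softmax_op(logits)`), while the PyTorch version goes through the static `apply` method, which is how `torch.autograd.Function` subclasses are invoked. A minimal, self-contained illustration of that calling convention (a toy function, not the DGL implementation):

```python
import torch as th

class Square(th.autograd.Function):
    """Toy custom autograd Function showing the .apply calling convention."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out

x = th.tensor([3.0], requires_grad=True)
y = Square.apply(x)   # invoked via .apply, never by instantiating the class
y.backward()
print(x.grad)         # tensor([6.])
```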