OpenDAS / dgl
Commit 6a4b5ae9 (unverified)
Authored Sep 09, 2019 by Zihao Ye; committed by GitHub on Sep 09, 2019
[Feature] Edge softmax on a subset of edges in the graph. (#842)
* upd * add test * fix * upd * merge * hotfix * upd * fix
Parent: bcd33e0a
Showing 5 changed files with 105 additions and 16 deletions (+105, -16).
Changed files:
- examples/mxnet/gat/utils.py  (+0, -1)
- python/dgl/nn/mxnet/softmax.py  (+22, -7)
- python/dgl/nn/pytorch/softmax.py  (+21, -7)
- tests/mxnet/test_nn.py  (+31, -0)
- tests/pytorch/test_nn.py  (+31, -1)
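As context for the file diffs below, a minimal usage sketch (not part of the commit) of the new optional eids argument, shown here with the PyTorch backend; the toy graph, edge list, and tensor values are assumptions for illustration only:

# Hypothetical usage of the new `eids` argument (PyTorch backend);
# assumes the DGLGraph construction API of this era of DGL.
import torch as th
import dgl
from dgl.nn.pytorch import edge_softmax

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2])  # 6 edges in total

logits = th.ones(g.number_of_edges(), 1)

# Previous behaviour: normalize over all in-edges of each destination node.
out_all = edge_softmax(g, logits)

# New behaviour: normalize only over the selected edges. The logits passed
# in must have one row per selected edge.
eids = th.tensor([0, 1, 2, 3])
out_partial = edge_softmax(g, logits[:4], eids)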
examples/mxnet/gat/utils.py

 import numpy as np
-import torch

 class EarlyStopping:
     def __init__(self, patience=10):
 ...
python/dgl/nn/mxnet/softmax.py

@@ -3,6 +3,7 @@
 import mxnet as mx
 from ... import function as fn
+from ...base import ALL, is_all

 __all__ = ['edge_softmax']
@@ -24,8 +25,10 @@ class EdgeSoftmax(mx.autograd.Function):
     the attention weights are computed with such an edgesoftmax operation.
     """
-    def __init__(self, g):
+    def __init__(self, g, eids):
         super(EdgeSoftmax, self).__init__()
+        if not is_all(eids):
+            g = g.edge_subgraph(eids.astype('int64'))
         self.g = g

     def forward(self, score):
@@ -78,7 +81,7 @@ class EdgeSoftmax(mx.autograd.Function):
         grad_score = g.edata['grad_score'] - g.edata['out']
         return grad_score

-def edge_softmax(graph, logits):
+def edge_softmax(graph, logits, eids=ALL):
     r"""Compute edge softmax.

     For a node :math:`i`, edge softmax is an operation of computing
@@ -98,8 +101,11 @@ def edge_softmax(graph, logits):
     ----------
     graph : DGLGraph
         The graph to perform edge softmax
-    logits : torch.Tensor
+    logits : mxnet.NDArray
         The input edge feature
+    eids : mxnet.NDArray or ALL, optional
+        Edges on which to apply edge softmax. If ALL, apply edge softmax
+        on all edges in the graph. Default: ALL.

     Returns
     -------
@@ -108,9 +114,10 @@ def edge_softmax(graph, logits):
     Notes
     -----
-    * Input shape: :math:`(N, *, 1)` where * means any number of
-      additional dimensions, :math:`N` is the number of edges.
-    * Return shape: :math:`(N, *, 1)`
+    * Input shape: :math:`(E, *, 1)` where * means any number of
+      additional dimensions, :math:`E` equals the length of eids.
+      If eids is ALL, :math:`E` equals number of edges in the graph.
+    * Return shape: :math:`(E, *, 1)`

     Examples
     --------
@@ -143,6 +150,14 @@ def edge_softmax(graph, logits):
     [0.33333334]
     [0.33333334]]
     <NDArray 6x1 @cpu(0)>
+
+    Apply edge softmax on first 4 edges of g:
+    >>> edge_softmax(g, edata, nd.array([0,1,2,3], dtype='int64'))
+    [[1. ]
+    [0.5]
+    [1. ]
+    [0.5]]
+    <NDArray 4x1 @cpu(0)>
     """
-    softmax_op = EdgeSoftmax(graph)
+    softmax_op = EdgeSoftmax(graph, eids)
     return softmax_op(logits)
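The MXNet change mirrors the PyTorch one below. A short sketch (again, not part of the commit; the graph and values are assumed) of the invariant that the new tests further down rely on, namely that a partial edge softmax over eids matches a full edge softmax on the edge-induced subgraph:

# Hypothetical check of the partial-softmax / edge-subgraph equivalence
# with the MXNet backend; the graph and eids are illustrative assumptions.
from mxnet import nd
import dgl
from dgl.nn.mxnet import edge_softmax

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2])

eids = nd.array([0, 1, 2, 3], dtype='int64')
logits = nd.random.randn(4, 1)

out_partial = edge_softmax(g, logits, eids)                 # new code path
out_subgraph = edge_softmax(g.edge_subgraph(eids), logits)  # reference
assert nd.abs(out_partial - out_subgraph).sum().asscalar() < 1e-6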
python/dgl/nn/pytorch/softmax.py

@@ -3,6 +3,7 @@
 import torch as th
 from ... import function as fn
+from ...base import ALL, is_all

 __all__ = ['edge_softmax']
@@ -25,7 +26,7 @@ class EdgeSoftmax(th.autograd.Function):
     """
     @staticmethod
-    def forward(ctx, g, score):
+    def forward(ctx, g, score, eids):
         """Forward function.

         Pseudo-code:
@@ -41,6 +42,8 @@ class EdgeSoftmax(th.autograd.Function):
         """
         # remember to save the graph to backward cache before making it
         # a local variable
+        if not is_all(eids):
+            g = g.edge_subgraph(eids.long())
         ctx.backward_cache = g
         g = g.local_var()
         g.edata['s'] = score
@@ -79,10 +82,10 @@ class EdgeSoftmax(th.autograd.Function):
         g.update_all(fn.copy_e('grad_s', 'm'), fn.sum('m', 'accum'))
         g.apply_edges(fn.e_mul_v('out', 'accum', 'out'))
         grad_score = g.edata['grad_s'] - g.edata['out']
-        return None, grad_score
+        return None, grad_score, None


-def edge_softmax(graph, logits):
+def edge_softmax(graph, logits, eids=ALL):
     r"""Compute edge softmax.

     For a node :math:`i`, edge softmax is an operation of computing
@@ -104,6 +107,9 @@ def edge_softmax(graph, logits):
         The graph to perform edge softmax
     logits : torch.Tensor
         The input edge feature
+    eids : torch.Tensor or ALL, optional
+        Edges on which to apply edge softmax. If ALL, apply edge
+        softmax on all edges in the graph. Default: ALL.

     Returns
     -------
@@ -112,9 +118,10 @@ def edge_softmax(graph, logits):
     Notes
     -----
-    * Input shape: :math:`(N, *, 1)` where * means any number of
-      additional dimensions, :math:`N` is the number of edges.
-    * Return shape: :math:`(N, *, 1)`
+    * Input shape: :math:`(E, *, 1)` where * means any number of
+      additional dimensions, :math:`E` equals the length of eids.
+      If eids is ALL, :math:`E` equals number of edges in the graph.
+    * Return shape: :math:`(E, *, 1)`

     Examples
     --------
@@ -145,5 +152,12 @@ def edge_softmax(graph, logits):
     [0.5000],
     [0.3333],
     [0.3333]])
+
+    Apply edge softmax on first 4 edges of g:
+    >>> edge_softmax(g, edata[:4], th.Tensor([0,1,2,3]))
+    tensor([[1.0000],
+            [0.5000],
+            [1.0000],
+            [0.5000]])
     """
-    return EdgeSoftmax.apply(graph, logits)
+    return EdgeSoftmax.apply(graph, logits, eids)
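One motivating use for the PyTorch-side change, sketched under assumptions (the sampled-attention setup below is illustrative and not taken from the commit): attention-style models such as GAT can normalize attention logits over a sampled subset of edges rather than the full edge set.

# Hypothetical sketch: edge softmax restricted to a random half of the edges.
# Graph sizes, the sampling scheme, and tensor names are assumptions.
import torch as th
import dgl
from dgl.nn.pytorch import edge_softmax

g = dgl.DGLGraph()
g.add_nodes(100)
src = th.randint(0, 100, (1000,))
dst = th.randint(0, 100, (1000,))
g.add_edges(src, dst)

# Sample half of the edges and compute attention logits only for them.
eids = th.randperm(g.number_of_edges())[:500]
logits = th.randn(len(eids), 1, requires_grad=True)

# Scores are normalized per destination node over the selected edges only.
att = edge_softmax(g, logits, eids)   # shape (500, 1)
att.sum().backward()                  # gradients flow back into `logits`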
tests/mxnet/test_nn.py

@@ -223,6 +223,36 @@ def test_edge_softmax():
     assert np.allclose(a.asnumpy(), uniform_attention(g, a.shape).asnumpy(),
                        1e-4, 1e-4)

+def test_partial_edge_softmax():
+    g = dgl.DGLGraph()
+    g.add_nodes(30)
+    # build a complete graph
+    for i in range(30):
+        for j in range(30):
+            g.add_edge(i, j)
+
+    score = F.randn((300, 1))
+    score.attach_grad()
+    grad = F.randn((300, 1))
+    import numpy as np
+    eids = np.random.choice(900, 300, replace=False).astype('int64')
+    eids = F.zerocopy_from_numpy(eids)
+    # compute partial edge softmax
+    with mx.autograd.record():
+        y_1 = nn.edge_softmax(g, score, eids)
+        y_1.backward(grad)
+    grad_1 = score.grad
+    # compute edge softmax on edge subgraph
+    subg = g.edge_subgraph(eids)
+    with mx.autograd.record():
+        y_2 = nn.edge_softmax(subg, score)
+        y_2.backward(grad)
+    grad_2 = score.grad
+
+    assert F.allclose(y_1, y_2)
+    assert F.allclose(grad_1, grad_2)
+
 def test_rgcn():
     ctx = F.ctx()
     etype = []
@@ -277,6 +307,7 @@ def test_rgcn():
 if __name__ == '__main__':
     test_graph_conv()
     test_edge_softmax()
+    test_partial_edge_softmax()
     test_set2set()
     test_glob_att_pool()
     test_simple_pool()
 ...
tests/pytorch/test_nn.py

@@ -319,7 +319,36 @@ def test_edge_softmax():
     assert len(g.ndata) == 0
     assert len(g.edata) == 2
     assert F.allclose(a1.grad, a2.grad, rtol=1e-4, atol=1e-4)  # Follow tolerance in unittest backend

+def test_partial_edge_softmax():
+    g = dgl.DGLGraph()
+    g.add_nodes(30)
+    # build a complete graph
+    for i in range(30):
+        for j in range(30):
+            g.add_edge(i, j)
+
+    score = F.randn((300, 1))
+    score.requires_grad_()
+    grad = F.randn((300, 1))
+    import numpy as np
+    eids = np.random.choice(900, 300, replace=False).astype('int64')
+    eids = F.zerocopy_from_numpy(eids)
+    # compute partial edge softmax
+    y_1 = nn.edge_softmax(g, score, eids)
+    y_1.backward(grad)
+    grad_1 = score.grad
+    score.grad.zero_()
+    # compute edge softmax on edge subgraph
+    subg = g.edge_subgraph(eids)
+    y_2 = nn.edge_softmax(subg, score)
+    y_2.backward(grad)
+    grad_2 = score.grad
+    score.grad.zero_()
+
+    assert F.allclose(y_1, y_2)
+    assert F.allclose(grad_1, grad_2)
+
 def test_rgcn():
     ctx = F.ctx()
     etype = []
@@ -570,6 +599,7 @@ def test_dense_cheb_conv():
 if __name__ == '__main__':
     test_graph_conv()
     test_edge_softmax()
+    test_partial_edge_softmax()
     test_set2set()
     test_glob_att_pool()
     test_simple_pool()
 ...