Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
45b610c4
Unverified
Commit
45b610c4
authored
Sep 09, 2020
by
Zihao Ye
Committed by
GitHub
Sep 09, 2020
Browse files
fix edge_softmax (#2160)
Co-authored-by:
Jinjing Zhou
<
VoVAllen@users.noreply.github.com
>
parent
0cf99be3
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
7 additions
and
4 deletions
+7
-4
python/dgl/backend/mxnet/sparse.py
python/dgl/backend/mxnet/sparse.py
+1
-1
python/dgl/backend/pytorch/sparse.py
python/dgl/backend/pytorch/sparse.py
+1
-1
python/dgl/backend/tensorflow/sparse.py
python/dgl/backend/tensorflow/sparse.py
+1
-1
python/dgl/ops/edge_softmax.py
python/dgl/ops/edge_softmax.py
+4
-1
No files found.
python/dgl/backend/mxnet/sparse.py
View file @
45b610c4
...
...
@@ -269,7 +269,7 @@ class EdgeSoftmax(mx.autograd.Function):
def
__init__
(
self
,
gidx
,
eids
,
norm_by
):
super
(
EdgeSoftmax
,
self
).
__init__
()
if
not
is_all
(
eids
):
gidx
=
gidx
.
edge_subgraph
(
eids
.
astype
(
gidx
.
dtype
)
,
True
)
gidx
=
gidx
.
edge_subgraph
(
[
eids
]
,
True
)
.
graph
if
norm_by
==
'src'
:
gidx
=
gidx
.
reverse
()
self
.
gidx
=
gidx
...
...
python/dgl/backend/pytorch/sparse.py
View file @
45b610c4
...
...
@@ -196,7 +196,7 @@ class EdgeSoftmax(th.autograd.Function):
# remember to save the graph to backward cache before making it
# a local variable
if
not
is_all
(
eids
):
gidx
=
gidx
.
edge_subgraph
(
eids
.
type
(
gidx
.
dtype
)
,
True
)
gidx
=
gidx
.
edge_subgraph
(
[
eids
]
,
True
)
.
graph
if
norm_by
==
'src'
:
gidx
=
gidx
.
reverse
()
score_max
=
_gspmm
(
gidx
,
'copy_rhs'
,
'max'
,
None
,
score
)[
0
]
...
...
python/dgl/backend/tensorflow/sparse.py
View file @
45b610c4
...
...
@@ -225,7 +225,7 @@ def gsddmm(gidx, op, X, Y, lhs_target='u', rhs_target='v'):
def
edge_softmax_real
(
gidx
,
score
,
eids
=
ALL
,
norm_by
=
'dst'
):
if
not
is_all
(
eids
):
gidx
=
gidx
.
edge_subgraph
(
tf
.
cast
(
eids
,
gidx
.
dtype
)
,
True
)
gidx
=
gidx
.
edge_subgraph
(
[
eids
]
,
True
)
.
graph
if
norm_by
==
'src'
:
gidx
=
gidx
.
reverse
()
score_max
=
_gspmm
(
gidx
,
'copy_rhs'
,
'max'
,
None
,
score
)[
0
]
...
...
python/dgl/ops/edge_softmax.py
View file @
45b610c4
"""dgl edge_softmax operator module."""
from
..backend
import
edge_softmax
as
edge_softmax_internal
from
..base
import
ALL
from
..backend
import
astype
from
..base
import
ALL
,
is_all
__all__
=
[
'edge_softmax'
]
...
...
@@ -103,5 +104,7 @@ def edge_softmax(graph, logits, eids=ALL, norm_by='dst'):
[1.0000],
[0.5000]])
"""
if
not
is_all
(
eids
):
eids
=
astype
(
eids
,
graph
.
idtype
)
return
edge_softmax_internal
(
graph
.
_graph
,
logits
,
eids
=
eids
,
norm_by
=
norm_by
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment