Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
e4948c5c
Unverified
Commit
e4948c5c
authored
Feb 03, 2020
by
Quan (Andy) Gan
Committed by
GitHub
Feb 03, 2020
Browse files
fix regression in #1237 (#1239)
parent
c167373e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
1 deletion
+15
-1
python/dgl/nn/pytorch/softmax.py
python/dgl/nn/pytorch/softmax.py
+15
-1
No files found.
python/dgl/nn/pytorch/softmax.py
View file @
e4948c5c
...
@@ -6,6 +6,8 @@ from ...function import TargetCode
...
@@ -6,6 +6,8 @@ from ...function import TargetCode
from
...base
import
ALL
,
is_all
from
...base
import
ALL
,
is_all
from
...
import
backend
as
F
from
...
import
backend
as
F
from
...
import
utils
from
...
import
utils
from
...graph
import
DGLGraph
from
...heterograph
import
DGLHeteroGraph
__all__
=
[
'edge_softmax'
]
__all__
=
[
'edge_softmax'
]
...
@@ -49,7 +51,19 @@ class EdgeSoftmax(th.autograd.Function):
...
@@ -49,7 +51,19 @@ class EdgeSoftmax(th.autograd.Function):
n_nodes
=
g
.
number_of_nodes
()
n_nodes
=
g
.
number_of_nodes
()
n_edges
=
g
.
number_of_edges
()
n_edges
=
g
.
number_of_edges
()
gidx
=
g
.
_graph
.
get_immutable_gidx
(
utils
.
to_dgl_context
(
score
.
device
))
# TODO(BarclayII): this is a temporary fix of memory leakage in PyTorch
# in PR #1139. We should investigate further on what was actually happening
# when implementing EdgeSoftmax with message passing API instead of
# operators.
score_context
=
utils
.
to_dgl_context
(
score
.
device
)
if
isinstance
(
g
,
DGLGraph
):
gidx
=
g
.
_graph
.
get_immutable_gidx
(
score_context
)
elif
isinstance
(
g
,
DGLHeteroGraph
):
assert
g
.
_graph
.
number_of_etypes
()
==
1
,
\
"EdgeSoftmax only support one edge type"
gidx
=
g
.
_graph
.
get_unitgraph
(
0
,
score_context
)
ctx
.
backward_cache
=
n_nodes
,
n_edges
,
gidx
ctx
.
backward_cache
=
n_nodes
,
n_edges
,
gidx
#g.update_all(fn.copy_e('s', 'm'), fn.max('m', 'smax'))
#g.update_all(fn.copy_e('s', 'm'), fn.max('m', 'smax'))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment