Unverified Commit e4948c5c authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

fix regression in #1237 (#1239)

parent c167373e
......@@ -6,6 +6,8 @@ from ...function import TargetCode
from ...base import ALL, is_all
from ... import backend as F
from ... import utils
from ...graph import DGLGraph
from ...heterograph import DGLHeteroGraph
__all__ = ['edge_softmax']
......@@ -49,7 +51,19 @@ class EdgeSoftmax(th.autograd.Function):
n_nodes = g.number_of_nodes()
n_edges = g.number_of_edges()
gidx = g._graph.get_immutable_gidx(utils.to_dgl_context(score.device))
# TODO(BarclayII): this is a temporary fix of memory leakage in PyTorch
# in PR #1139. We should investigate further on what was actually happening
# when implementing EdgeSoftmax with message passing API instead of
# operators.
score_context = utils.to_dgl_context(score.device)
if isinstance(g, DGLGraph):
gidx = g._graph.get_immutable_gidx(score_context)
elif isinstance(g, DGLHeteroGraph):
assert g._graph.number_of_etypes() == 1, \
"EdgeSoftmax only support one edge type"
gidx = g._graph.get_unitgraph(0, score_context)
ctx.backward_cache = n_nodes, n_edges, gidx
#g.update_all(fn.copy_e('s', 'm'), fn.max('m', 'smax'))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment