Unverified Commit 45b610c4 authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

fix edge_softmax (#2160)


Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
parent 0cf99be3
...@@ -269,7 +269,7 @@ class EdgeSoftmax(mx.autograd.Function): ...@@ -269,7 +269,7 @@ class EdgeSoftmax(mx.autograd.Function):
def __init__(self, gidx, eids, norm_by): def __init__(self, gidx, eids, norm_by):
super(EdgeSoftmax, self).__init__() super(EdgeSoftmax, self).__init__()
if not is_all(eids): if not is_all(eids):
gidx = gidx.edge_subgraph(eids.astype(gidx.dtype), True) gidx = gidx.edge_subgraph([eids], True).graph
if norm_by == 'src': if norm_by == 'src':
gidx = gidx.reverse() gidx = gidx.reverse()
self.gidx = gidx self.gidx = gidx
......
...@@ -196,7 +196,7 @@ class EdgeSoftmax(th.autograd.Function): ...@@ -196,7 +196,7 @@ class EdgeSoftmax(th.autograd.Function):
# remember to save the graph to backward cache before making it # remember to save the graph to backward cache before making it
# a local variable # a local variable
if not is_all(eids): if not is_all(eids):
gidx = gidx.edge_subgraph(eids.type(gidx.dtype), True) gidx = gidx.edge_subgraph([eids], True).graph
if norm_by == 'src': if norm_by == 'src':
gidx = gidx.reverse() gidx = gidx.reverse()
score_max = _gspmm(gidx, 'copy_rhs', 'max', None, score)[0] score_max = _gspmm(gidx, 'copy_rhs', 'max', None, score)[0]
......
...@@ -225,7 +225,7 @@ def gsddmm(gidx, op, X, Y, lhs_target='u', rhs_target='v'): ...@@ -225,7 +225,7 @@ def gsddmm(gidx, op, X, Y, lhs_target='u', rhs_target='v'):
def edge_softmax_real(gidx, score, eids=ALL, norm_by='dst'): def edge_softmax_real(gidx, score, eids=ALL, norm_by='dst'):
if not is_all(eids): if not is_all(eids):
gidx = gidx.edge_subgraph(tf.cast(eids, gidx.dtype), True) gidx = gidx.edge_subgraph([eids], True).graph
if norm_by == 'src': if norm_by == 'src':
gidx = gidx.reverse() gidx = gidx.reverse()
score_max = _gspmm(gidx, 'copy_rhs', 'max', None, score)[0] score_max = _gspmm(gidx, 'copy_rhs', 'max', None, score)[0]
......
"""dgl edge_softmax operator module.""" """dgl edge_softmax operator module."""
from ..backend import edge_softmax as edge_softmax_internal from ..backend import edge_softmax as edge_softmax_internal
from ..base import ALL from ..backend import astype
from ..base import ALL, is_all
__all__ = ['edge_softmax'] __all__ = ['edge_softmax']
...@@ -103,5 +104,7 @@ def edge_softmax(graph, logits, eids=ALL, norm_by='dst'): ...@@ -103,5 +104,7 @@ def edge_softmax(graph, logits, eids=ALL, norm_by='dst'):
[1.0000], [1.0000],
[0.5000]]) [0.5000]])
""" """
if not is_all(eids):
eids = astype(eids, graph.idtype)
return edge_softmax_internal(graph._graph, logits, return edge_softmax_internal(graph._graph, logits,
eids=eids, norm_by=norm_by) eids=eids, norm_by=norm_by)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment