Unverified Commit fb3c0709 authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

Revert part of #2563 (#2584)

parent 878acdb0
......@@ -27,34 +27,6 @@ else:
__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce']
_inverse_format = {
'coo': 'coo',
'csr': 'csc',
'csc': 'csr'
}
def _reverse(gidx):
"""Reverse the given graph index while retaining its formats.
Parameters
----------
gidx: HeteroGraphIndex
Return
------
HeteroGraphIndex
"""
g_rev = gidx.reverse()
original_formats_dict = gidx.formats()
original_formats = original_formats_dict['created'] +\
original_formats_dict['not created']
g_rev = g_rev.formats([_inverse_format[fmt] for fmt in original_formats])
return g_rev
def _reduce_grad(grad, shape):
"""Reduce gradient on the broadcast dimension
If there is broadcast in forward pass, gradients need to be reduced on
......@@ -123,7 +95,7 @@ class GSpMM(th.autograd.Function):
gidx, op, reduce_op = ctx.backward_cache
X, Y, argX, argY = ctx.saved_tensors
if op != 'copy_rhs' and ctx.needs_input_grad[3]:
g_rev = _reverse(gidx)
g_rev = gidx.reverse()
if reduce_op == 'sum':
if op in ['mul', 'div']:
dX = gspmm(g_rev, 'mul', 'sum', dZ, _muldiv(op, Y))
......@@ -186,7 +158,7 @@ class GSDDMM(th.autograd.Function):
X, Y = ctx.saved_tensors
if op != 'copy_rhs' and ctx.needs_input_grad[2]:
if lhs_target in ['u', 'v']:
_gidx = gidx if lhs_target == 'v' else _reverse(gidx)
_gidx = gidx if lhs_target == 'v' else gidx.reverse()
if op in ['add', 'sub', 'copy_lhs']:
dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)
else: # mul, div, dot
......@@ -206,7 +178,7 @@ class GSDDMM(th.autograd.Function):
dX = None
if op != 'copy_lhs' and ctx.needs_input_grad[3]:
if rhs_target in ['u', 'v']:
_gidx = gidx if rhs_target == 'v' else _reverse(gidx)
_gidx = gidx if rhs_target == 'v' else gidx.reverse()
if op in ['add', 'sub', 'copy_rhs']:
dY = gspmm(_gidx, 'copy_rhs', 'sum', None, _addsub(op, dZ))
else: # mul, div, dot
......@@ -253,7 +225,7 @@ class EdgeSoftmax(th.autograd.Function):
if not is_all(eids):
gidx = gidx.edge_subgraph([eids], True).graph
if norm_by == 'src':
gidx = _reverse(gidx)
gidx = gidx.reverse()
score_max = _gspmm(gidx, 'copy_rhs', 'max', None, score)[0]
score = th.exp(_gsddmm(gidx, 'sub', score, score_max, 'e', 'v'))
score_sum = _gspmm(gidx, 'copy_rhs', 'sum', None, score)[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment