Unverified commit 2efdaa5d, authored by Quan (Andy) Gan and committed by GitHub

[Bug] Revert clearing backward cache for retain_graph flag (#4249)


Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
parent 5c76e47f
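
This change reverts the eager clearing of ctx.backward_cache introduced in https://github.com/dmlc/dgl/pull/3386 (the comments deleted below point to that PR). PyTorch's autograd reuses the same ctx when backward(retain_graph=True) is followed by another backward pass, so a backward() that sets ctx.backward_cache = None on its first invocation leaves the second invocation reading None. A minimal sketch of the failure mode with a hypothetical toy op (CachedScale is illustrative, not DGL code):

import torch as th

class CachedScale(th.autograd.Function):
    # Toy stand-in for DGL's pattern: stash non-tensor state on
    # ctx.backward_cache in forward() and read it back in backward().
    @staticmethod
    def forward(ctx, x):
        ctx.backward_cache = 2.0   # non-tensor state needed by backward
        return x * 2.0

    @staticmethod
    def backward(ctx, dz):
        scale = ctx.backward_cache
        # The reverted code also did the following, which breaks the second
        # pass below, because autograd calls backward() with the same ctx:
        # ctx.backward_cache = None
        return dz * scale

x = th.ones(3, requires_grad=True)
y = CachedScale.apply(x).sum()
y.backward(retain_graph=True)   # first pass: fine either way
y.backward()                    # second pass: needs the cache intact

In the real kernels below, the second pass fails at the tuple unpacking (gidx, op, ... = ctx.backward_cache with the cache already set to None), raising TypeError: cannot unpack non-iterable NoneType object.
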
@@ -127,7 +127,6 @@ class GSpMM(th.autograd.Function):
     @custom_bwd
     def backward(ctx, dZ):
         gidx, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last = ctx.backward_cache
-        ctx.backward_cache = None
         X, Y, argX, argY = ctx.saved_tensors
         if op != 'copy_rhs' and ctx.needs_input_grad[3]:
             g_rev = gidx.reverse()
@@ -207,7 +206,6 @@ class GSpMM_hetero(th.autograd.Function):
     @custom_bwd
     def backward(ctx, *dZ):
         gidx, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last, X_len = ctx.backward_cache
-        ctx.backward_cache = None
         num_ntypes = gidx.number_of_ntypes()
         feats = ctx.saved_tensors[:-(4 * num_ntypes)]
         argX = ctx.saved_tensors[-(4 * num_ntypes):-(3 * num_ntypes)]
@@ -305,7 +303,6 @@ class GSDDMM(th.autograd.Function):
     @custom_bwd
     def backward(ctx, dZ):
         gidx, op, lhs_target, rhs_target, X_shape, Y_shape = ctx.backward_cache
-        ctx.backward_cache = None
         X, Y = ctx.saved_tensors
         if op != 'copy_rhs' and ctx.needs_input_grad[2]:
             if lhs_target in ['u', 'v']:
@@ -373,7 +370,6 @@ class GSDDMM_hetero(th.autograd.Function):
     # TODO(Israt): Implement the complete backward operator
     def backward(ctx, *dZ):
         gidx, op, lhs_target, rhs_target, X_shape, Y_shape, X_len = ctx.backward_cache
-        ctx.backward_cache = None
         feats = ctx.saved_tensors
         X, Y = feats[:X_len], feats[X_len:]
         if op != 'copy_rhs' and any([x is not None for x in X]):
@@ -484,8 +480,6 @@ class EdgeSoftmax(th.autograd.Function):
             return grad_score.data
         """
         gidx = ctx.backward_cache
-        # See https://github.com/dmlc/dgl/pull/3386
-        ctx.backward_cache = None
         out, = ctx.saved_tensors
         sds = out * grad_out
         #Note: Now _edge_softmax_backward op only supports CPU
@@ -554,8 +548,6 @@ class EdgeSoftmax_hetero(th.autograd.Function):
             return grad_score.data
         """
         gidx = ctx.backward_cache
-        # See https://github.com/dmlc/dgl/pull/3386
-        ctx.backward_cache = None
         u_len = gidx.number_of_ntypes()
         e_len = gidx.number_of_etypes()
         lhs = [None] * u_len
@@ -582,8 +574,6 @@ class SegmentReduce(th.autograd.Function):
     @custom_bwd
     def backward(ctx, dy):
         op = ctx.backward_cache
-        # See https://github.com/dmlc/dgl/pull/3386
-        ctx.backward_cache = None
         arg, offsets = ctx.saved_tensors
         m = offsets[-1].item()
         if op == 'sum':
@@ -630,7 +620,6 @@ class CSRMM(th.autograd.Function):
     def backward(ctx, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights):
         # Only the last argument is meaningful.
         gidxA, gidxB, gidxC = ctx.backward_cache
-        ctx.backward_cache = None
         A_weights, B_weights = ctx.saved_tensors
         dgidxA, dA_weights = csrmm(
             gidxC, dC_weights, gidxB.reverse(), B_weights, gidxA.number_of_ntypes())
@@ -657,7 +646,6 @@ class CSRSum(th.autograd.Function):
     def backward(ctx, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights):
         # Only the last argument is meaningful.
         gidxs, gidxC = ctx.backward_cache
-        ctx.backward_cache = None
         return (None,) + tuple(csrmask(gidxC, dC_weights, gidx) for gidx in gidxs)
 
 
@@ -670,7 +658,6 @@ class CSRMask(th.autograd.Function):
     @staticmethod
     def backward(ctx, dB_weights):
         gidxA, gidxB = ctx.backward_cache
-        ctx.backward_cache = None
         return None, csrmask(gidxB, dB_weights, gidxA), None
 
 
@@ -418,8 +418,6 @@ class BinaryReduce(th.autograd.Function):
     def backward(ctx, grad_out):
         reducer, binary_op, graph, lhs, rhs, lhs_map, rhs_map, out_map, \
             feat_shape, degs = ctx.backward_cache
-        # See https://github.com/dmlc/dgl/pull/3386
-        ctx.backward_cache = None
         lhs_data, rhs_data, out_data = ctx.saved_tensors
         lhs_data_nd = zerocopy_to_dgl_ndarray(lhs_data)
         rhs_data_nd = zerocopy_to_dgl_ndarray(rhs_data)
@@ -497,8 +495,6 @@ class CopyReduce(th.autograd.Function):
     @staticmethod
     def backward(ctx, grad_out):
         reducer, graph, target, in_map, out_map, degs = ctx.backward_cache
-        # See https://github.com/dmlc/dgl/pull/3386
-        ctx.backward_cache = None
         in_data, out_data = ctx.saved_tensors
         in_data_nd = zerocopy_to_dgl_ndarray(in_data)
         out_data_nd = zerocopy_to_dgl_ndarray(out_data)
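
With the clearing reverted, call sites that backpropagate through these kernels more than once work again, at the cost of keeping the cached state alive until autograd frees the ctx itself, which is presumably what PR #3386 had tried to avoid. A minimal sketch of the call pattern this fixes, assuming a standard DGL + PyTorch install (dgl.ops.copy_u_sum is one of the generated ops that runs GSpMM under the hood):

import torch as th
import dgl

g = dgl.graph(([0, 1, 2], [1, 2, 0]))   # tiny directed 3-cycle
x = th.randn(3, 4, requires_grad=True)

y = dgl.ops.copy_u_sum(g, x).sum()      # forward fills ctx.backward_cache
y.backward(retain_graph=True)           # first backward reads the cache
y.backward()                            # second backward needs it intact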