Unverified Commit ff94ee80 authored by Cheng Wan, committed by GitHub

[BugFix] Avoid Memory Leak Issue in PyTorch Backend (#3386)



* try to avoid memory leak

* try to avoid memory leak

* avoid memory leak with no hope

* Revert "avoid memory leak with no hope"

This reverts commit c77befe9479f46758e744642f66dd209b50eef7d.

* no message

* Update sparse.py

* Update tensor.py
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent f7039418
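
The diff below applies one pattern throughout: every custom th.autograd.Function that stashes non-tensor state (a graph index, operator names, shapes) on ctx.backward_cache in forward() now resets that attribute to None at the top of backward(). Because the output tensor keeps its grad_fn, and grad_fn keeps ctx, alive, anything left on ctx stays reachable after backward and can keep graph structures and device memory pinned. A minimal sketch of the pattern, with a hypothetical MyMul operator standing in for DGL's kernels:

import torch as th

class MyMul(th.autograd.Function):
    """Hypothetical toy operator illustrating the caching pattern."""

    @staticmethod
    def forward(ctx, x, gidx):
        # Non-tensor state cannot go through save_for_backward(), so it is
        # stored as a plain attribute; tensors still use save_for_backward().
        ctx.backward_cache = gidx
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, dZ):
        gidx = ctx.backward_cache
        # Drop the reference once it has been read: ctx lives as long as the
        # output's grad_fn does, so leaving gidx here would keep it (and
        # whatever memory it pins) alive well after backward has finished.
        ctx.backward_cache = None
        x, = ctx.saved_tensors
        return dZ * 2 * x, None  # one gradient per forward input; None for gidx
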
@@ -144,6 +144,7 @@ class GSpMM(th.autograd.Function):
     @custom_bwd
     def backward(ctx, dZ):
         gidx, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last = ctx.backward_cache
+        ctx.backward_cache = None
         X, Y, argX, argY = ctx.saved_tensors
         if op != 'copy_rhs' and ctx.needs_input_grad[3]:
             g_rev = gidx.reverse()
@@ -224,6 +225,7 @@ class GSpMM_hetero(th.autograd.Function):
     @custom_bwd
     def backward(ctx, *dZ):
         gidx, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last, X_len = ctx.backward_cache
+        ctx.backward_cache = None
         feats = ctx.saved_tensors[:-2]
         argX = ctx.saved_tensors[-2]
         argY = ctx.saved_tensors[-1]
@@ -298,6 +300,7 @@ class GSDDMM(th.autograd.Function):
     @custom_bwd
     def backward(ctx, dZ):
         gidx, op, lhs_target, rhs_target, X_shape, Y_shape = ctx.backward_cache
+        ctx.backward_cache = None
         X, Y = ctx.saved_tensors
         if op != 'copy_rhs' and ctx.needs_input_grad[2]:
             if lhs_target in ['u', 'v']:
@@ -365,6 +368,7 @@ class GSDDMM_hetero(th.autograd.Function):
     # TODO(Israt): Implement the complete backward operator
     def backward(ctx, *dZ):
         gidx, op, lhs_target, rhs_target, X_shape, Y_shape, X_len = ctx.backward_cache
+        ctx.backward_cache = None
         feats = ctx.saved_tensors
         X, Y = feats[:X_len], feats[X_len:]
         if op != 'copy_rhs' and any([x is not None for x in X]):
@@ -469,6 +473,8 @@ class EdgeSoftmax(th.autograd.Function):
             return grad_score.data
         """
         gidx = ctx.backward_cache
+        # See https://github.com/dmlc/dgl/pull/3386
+        ctx.backward_cache = None
         out, = ctx.saved_tensors
         sds = out * grad_out
         accum = gspmm(gidx, 'copy_rhs', 'sum', None, sds)
@@ -489,6 +495,8 @@ class SegmentReduce(th.autograd.Function):
     @custom_bwd
     def backward(ctx, dy):
         op = ctx.backward_cache
+        # See https://github.com/dmlc/dgl/pull/3386
+        ctx.backward_cache = None
         arg, offsets = ctx.saved_tensors
         m = offsets[-1].item()
         if op == 'sum':
@@ -535,6 +543,7 @@ class CSRMM(th.autograd.Function):
     def backward(ctx, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights):
         # Only the last argument is meaningful.
         gidxA, gidxB, gidxC = ctx.backward_cache
+        ctx.backward_cache = None
         A_weights, B_weights = ctx.saved_tensors
         dgidxA, dA_weights = csrmm(
             gidxC, dC_weights, gidxB.reverse(), B_weights, gidxA.number_of_ntypes())
@@ -561,6 +570,7 @@ class CSRSum(th.autograd.Function):
     def backward(ctx, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights):
         # Only the last argument is meaningful.
         gidxs, gidxC = ctx.backward_cache
+        ctx.backward_cache = None
         return (None,) + tuple(csrmask(gidxC, dC_weights, gidx) for gidx in gidxs)
@@ -573,6 +583,7 @@ class CSRMask(th.autograd.Function):
     @staticmethod
     def backward(ctx, dB_weights):
         gidxA, gidxB = ctx.backward_cache
+        ctx.backward_cache = None
         return None, csrmask(gidxB, dB_weights, gidxA), None
...
@@ -401,6 +401,8 @@ class BinaryReduce(th.autograd.Function):
     def backward(ctx, grad_out):
         reducer, binary_op, graph, lhs, rhs, lhs_map, rhs_map, out_map, \
             feat_shape, degs = ctx.backward_cache
+        # See https://github.com/dmlc/dgl/pull/3386
+        ctx.backward_cache = None
         lhs_data, rhs_data, out_data = ctx.saved_tensors
         lhs_data_nd = zerocopy_to_dgl_ndarray(lhs_data)
         rhs_data_nd = zerocopy_to_dgl_ndarray(rhs_data)
@@ -478,6 +480,8 @@ class CopyReduce(th.autograd.Function):
     @staticmethod
     def backward(ctx, grad_out):
         reducer, graph, target, in_map, out_map, degs = ctx.backward_cache
+        # See https://github.com/dmlc/dgl/pull/3386
+        ctx.backward_cache = None
         in_data, out_data = ctx.saved_tensors
         in_data_nd = zerocopy_to_dgl_ndarray(in_data)
         out_data_nd = zerocopy_to_dgl_ndarray(out_data)
...
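
A quick way to sanity-check the pattern locally is to watch a weak reference to the cached object die once backward() has run. This is a self-contained, hypothetical check; _CacheHolder and ScaleByTwo are stand-ins, not DGL code:

import weakref
import torch as th

class _CacheHolder:
    """Stand-in for a non-tensor object (e.g. a graph index) kept on ctx."""

class ScaleByTwo(th.autograd.Function):
    @staticmethod
    def forward(ctx, x, holder):
        ctx.backward_cache = holder   # non-tensor state for backward
        return x * 2

    @staticmethod
    def backward(ctx, dZ):
        _ = ctx.backward_cache        # backward would normally use this
        ctx.backward_cache = None     # release the cached object after use
        return dZ * 2, None

holder = _CacheHolder()
ref = weakref.ref(holder)
x = th.ones(3, requires_grad=True)
y = ScaleByTwo.apply(x, holder)
y.sum().backward()
del holder
# y.grad_fn still exists, but its ctx no longer references the holder, so the
# weak reference should now be dead; without the reset it would stay alive.
print(ref() is None)  # expected: True
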