Unverified Commit 49b406c9 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

memory leak fix in PyTorch (#1060)

parent 22d4de77
...@@ -326,15 +326,17 @@ class BinaryReduce(th.autograd.Function): ...@@ -326,15 +326,17 @@ class BinaryReduce(th.autograd.Function):
# save_for_backward can only save variables # save_for_backward can only save variables
ctx.backward_cache = (reducer, binary_op, graph, lhs, rhs, lhs_map, ctx.backward_cache = (reducer, binary_op, graph, lhs, rhs, lhs_map,
rhs_map, out_map, lhs_data_nd, rhs_data_nd, rhs_map, out_map, lhs_data_nd, rhs_data_nd,
out_data_nd, feat_shape, degs) feat_shape, degs)
ctx.save_for_backward(out_data)
return out_data return out_data
@staticmethod @staticmethod
def backward(ctx, grad_out): def backward(ctx, grad_out):
reducer, binary_op, graph, lhs, rhs, lhs_map, rhs_map, out_map, \ reducer, binary_op, graph, lhs, rhs, lhs_map, rhs_map, out_map, \
lhs_data_nd, rhs_data_nd, out_data_nd, feat_shape, degs \ lhs_data_nd, rhs_data_nd, feat_shape, degs \
= ctx.backward_cache = ctx.backward_cache
ctx.backward_cache = None out_data, = ctx.saved_variables
out_data_nd = zerocopy_to_dgl_ndarray(out_data)
grad_lhs = None grad_lhs = None
grad_rhs = None grad_rhs = None
if reducer == 'mean': if reducer == 'mean':
...@@ -387,14 +389,16 @@ class CopyReduce(th.autograd.Function): ...@@ -387,14 +389,16 @@ class CopyReduce(th.autograd.Function):
degs = None degs = None
# save_for_backward can only save variables # save_for_backward can only save variables
ctx.backward_cache = (reducer, graph, target, in_map, out_map, ctx.backward_cache = (reducer, graph, target, in_map, out_map,
in_data_nd, out_data_nd, degs) in_data_nd, degs)
ctx.save_for_backward(out_data)
return out_data return out_data
@staticmethod @staticmethod
def backward(ctx, grad_out): def backward(ctx, grad_out):
reducer, graph, target, in_map, out_map, in_data_nd, out_data_nd, degs \ reducer, graph, target, in_map, out_map, in_data_nd, degs \
= ctx.backward_cache = ctx.backward_cache
ctx.backward_cache = None out_data, = ctx.saved_variables
out_data_nd = zerocopy_to_dgl_ndarray(out_data)
grad_in = None grad_in = None
if reducer == 'mean': if reducer == 'mean':
grad_out = grad_out / degs grad_out = grad_out / degs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment