Unverified commit 86fe58eb, authored by Quan (Andy) Gan and committed by GitHub

[Bug] Another fix on PyTorch memory leakage issue (#1139)



* another fix

* another try

* fix

* rewriting with kernel functions

* revert mxnet softmax changes

* lint fix
Co-authored-by: Zihao Ye <zihaoye.cs@gmail.com>
parent a199b58a
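
In short (as the diffs below show): the PyTorch `BinaryReduce`/`CopyReduce` autograd functions used to allocate their output inside `forward()` and stash DGL NDArray views of the input tensors on `ctx.backward_cache`, keeping references to that memory outside of autograd's `save_for_backward` mechanism. They now receive a pre-allocated output buffer from thin `binary_reduce`/`copy_reduce` wrapper functions, save the tensors that `backward()` needs via `ctx.save_for_backward()`, and reconstruct the zero-copy NDArray views on demand in `backward()`. The PyTorch `edge_softmax` is rewritten on top of these kernel wrappers instead of graph message passing, and the MXNet/TensorFlow wrappers gain matching default arguments. A minimal sketch of the saving pattern (illustrative names, not DGL code):

    import torch as th

    class _Square(th.autograd.Function):
        """Toy example: y = x * x, saving the tensor backward() needs the recommended way."""

        @staticmethod
        def forward(ctx, x):
            # Only lightweight, non-tensor state stays on ctx; tensors that backward()
            # will read go through save_for_backward(), so autograd manages their lifetime.
            ctx.backward_cache = ('metadata', 'only')
            ctx.save_for_backward(x)
            return x * x

        @staticmethod
        def backward(ctx, grad_out):
            x, = ctx.saved_tensors   # re-derive anything else (e.g. zero-copy views) here
            return 2.0 * x * grad_out

    x = th.randn(3, requires_grad=True)
    _Square.apply(x).sum().backward()   # x.grad == 2 * x
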
@@ -445,7 +445,7 @@ class BinaryReduce(mx.autograd.Function):
 def binary_reduce(reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data,
-                  out_size, lhs_map, rhs_map, out_map):
+                  out_size, lhs_map=(None, None), rhs_map=(None, None), out_map=(None, None)):
     func = BinaryReduce(reducer, binary_op, graph, lhs, rhs, out_size, lhs_map,
                         rhs_map, out_map)
     return func(lhs_data, rhs_data)
@@ -508,7 +508,8 @@ class CopyReduce(mx.autograd.Function):
         return grad_in

-def copy_reduce(reducer, graph, target, in_data, out_size, in_map, out_map):
+def copy_reduce(reducer, graph, target, in_data, out_size, in_map=(None, None),
+                out_map=(None, None)):
     func = CopyReduce(reducer, graph, target, out_size, in_map, out_map)
     return func(in_data)
...
@@ -285,7 +285,7 @@ def zerocopy_from_dgl_ndarray(input):
 class BinaryReduce(th.autograd.Function):
     @staticmethod
-    def forward(ctx, reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data,
+    def forward(ctx, reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data, out_data,
                 out_size, lhs_map, rhs_map, out_map):
         lhs_data_nd = zerocopy_to_dgl_ndarray(lhs_data)
         rhs_data_nd = zerocopy_to_dgl_ndarray(rhs_data)
@@ -293,7 +293,6 @@ class BinaryReduce(th.autograd.Function):
         out_shape = feat_shape
         if binary_op == 'dot':
             out_shape = feat_shape[:-1]
-        out_data = lhs_data.new_empty((out_size,) + out_shape)
         out_data_nd = zerocopy_to_dgl_ndarray(out_data)
         K.binary_op_reduce(
             reducer if reducer != 'mean' else 'sum',
@@ -323,17 +322,17 @@ class BinaryReduce(th.autograd.Function):
             degs = None
         # save_for_backward can only save variables
         ctx.backward_cache = (reducer, binary_op, graph, lhs, rhs, lhs_map,
-                              rhs_map, out_map, lhs_data_nd, rhs_data_nd,
-                              feat_shape, degs)
-        ctx.save_for_backward(out_data)
+                              rhs_map, out_map, feat_shape, degs)
+        ctx.save_for_backward(lhs_data, rhs_data, out_data)
         return out_data

     @staticmethod
     def backward(ctx, grad_out):
         reducer, binary_op, graph, lhs, rhs, lhs_map, rhs_map, out_map, \
-            lhs_data_nd, rhs_data_nd, feat_shape, degs \
-            = ctx.backward_cache
-        out_data, = ctx.saved_tensors
+            feat_shape, degs = ctx.backward_cache
+        lhs_data, rhs_data, out_data = ctx.saved_tensors
+        lhs_data_nd = zerocopy_to_dgl_ndarray(lhs_data)
+        rhs_data_nd = zerocopy_to_dgl_ndarray(rhs_data)
         out_data_nd = zerocopy_to_dgl_ndarray(out_data)
         grad_lhs = None
         grad_rhs = None
@@ -357,19 +356,34 @@ class BinaryReduce(th.autograd.Function):
                 lhs_map[1], rhs_map[1], out_map[1])
             grad_rhs = _reduce_grad(grad_rhs, rhs_data_nd.shape)
-        return None, None, None, None, None, grad_lhs, grad_rhs, None, None, \
-            None, None
+        return None, None, None, None, None, grad_lhs, grad_rhs, None, None, None, \
+            None, None

+def binary_reduce(reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data,
+                  out_size, lhs_map=(None, None), rhs_map=(None, None), out_map=(None, None)):
+    lhs_data_nd = zerocopy_to_dgl_ndarray(lhs_data)
+    rhs_data_nd = zerocopy_to_dgl_ndarray(rhs_data)
+    feat_shape = K.infer_binary_feature_shape(binary_op, lhs_data_nd, rhs_data_nd)
+    out_shape = feat_shape
+    if binary_op == 'dot':
+        out_shape = feat_shape[:-1]
+    out_data = lhs_data.new_empty((out_size,) + out_shape)
+    return BinaryReduce.apply(
+        reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data, out_data,
+        out_size, lhs_map, rhs_map, out_map)

 class CopyReduce(th.autograd.Function):
     @staticmethod
-    def forward(ctx, reducer, graph, target, in_data, out_size, in_map,
+    def forward(ctx, reducer, graph, target, in_data, out_data, out_size, in_map,
                 out_map):
-        out_data = in_data.new_empty((out_size,) + in_data.shape[1:])
         in_data_nd = zerocopy_to_dgl_ndarray(in_data)
         out_data_nd = zerocopy_to_dgl_ndarray(out_data)
         K.copy_reduce(
             reducer if reducer != 'mean' else 'sum',
             graph, target, in_data_nd, out_data_nd, in_map[0], out_map[0])
         # normalize if mean reducer
         # NOTE(zihao): this is a temporary hack and we should have better solution in the future.
@@ -379,23 +393,22 @@ class CopyReduce(th.autograd.Function):
             in_ones_nd = zerocopy_to_dgl_ndarray(in_ones)
             degs_nd = zerocopy_to_dgl_ndarray(degs)
             K.copy_reduce(
                 'sum', graph, target, in_ones_nd, degs_nd, in_map[0], out_map[0])
             # reshape
             degs = degs.reshape((out_data.shape[0],) + (1,) * (out_data.dim() - 1)).clamp(min=1)
             out_data = out_data / degs
         else:
             degs = None
         # save_for_backward can only save variables
-        ctx.backward_cache = (reducer, graph, target, in_map, out_map,
-                              in_data_nd, degs)
-        ctx.save_for_backward(out_data)
+        ctx.backward_cache = (reducer, graph, target, in_map, out_map, degs)
+        ctx.save_for_backward(in_data, out_data)
         return out_data

     @staticmethod
     def backward(ctx, grad_out):
-        reducer, graph, target, in_map, out_map, in_data_nd, degs \
-            = ctx.backward_cache
-        out_data, = ctx.saved_tensors
+        reducer, graph, target, in_map, out_map, degs = ctx.backward_cache
+        in_data, out_data = ctx.saved_tensors
+        in_data_nd = zerocopy_to_dgl_ndarray(in_data)
         out_data_nd = zerocopy_to_dgl_ndarray(out_data)
         grad_in = None
         if reducer == 'mean':
@@ -404,14 +417,16 @@ class CopyReduce(th.autograd.Function):
         if ctx.needs_input_grad[3]:
             grad_in = grad_out.new_empty(in_data_nd.shape)
             K.backward_copy_reduce(
                 reducer if reducer != 'mean' else 'sum',
                 graph, target, in_data_nd, out_data_nd, grad_out_nd,
                 zerocopy_to_dgl_ndarray(grad_in), in_map[1], out_map[1])
-        return None, None, None, grad_in, None, None, None
+        return None, None, None, grad_in, None, None, None, None

-binary_reduce = BinaryReduce.apply
-copy_reduce = CopyReduce.apply
+def copy_reduce(reducer, graph, target, in_data, out_size, in_map=(None, None),
+                out_map=(None, None)):
+    out_data = in_data.new_empty((out_size,) + in_data.shape[1:])
+    return CopyReduce.apply(reducer, graph, target, in_data, out_data, out_size, in_map, out_map)

 def _reduce_grad(grad, shape):
...
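
One detail worth noting in the hunk above: `binary_reduce` and `copy_reduce` are no longer bare aliases of `Function.apply`. An alias cannot carry default argument values or do any preprocessing, so both the output allocation and the `(None, None)` map defaults move into small wrapper functions, and `forward()` gains an extra argument (hence the extra `None` in each `backward` return). A sketch of that shape, with hypothetical names rather than DGL code:

    import torch as th

    class _AddScaled(th.autograd.Function):
        @staticmethod
        def forward(ctx, x, y, out, scale):
            # Stand-in for the kernel call: write the result into the buffer the wrapper made.
            out.copy_(x + scale * y)
            ctx.mark_dirty(out)           # the buffer was modified in place by a torch op
            ctx.scale = scale             # non-tensor state may stay on ctx
            ctx.save_for_backward(x, y, out)
            return out

        @staticmethod
        def backward(ctx, grad_out):
            # One gradient slot per forward argument; the buffer and the scale get None.
            return grad_out, ctx.scale * grad_out, None, None

    def add_scaled(x, y, scale=1.0):
        # The wrapper owns the default value and the output allocation;
        # an alias "add_scaled = _AddScaled.apply" could do neither.
        out = x.new_empty(x.shape)
        return _AddScaled.apply(x, y, out, scale)

    a = th.randn(4, requires_grad=True)
    b = th.randn(4, requires_grad=True)
    add_scaled(a, b, scale=3.0).sum().backward()   # a.grad == 1, b.grad == 3 elementwise
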
@@ -385,7 +385,7 @@ def zerocopy_from_dgl_ndarray(input):
 def binary_reduce(reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data,
-                  out_size, lhs_map, rhs_map, out_map):
+                  out_size, lhs_map=(None, None), rhs_map=(None, None), out_map=(None, None)):
     @tf.custom_gradient
     def _lambda(lhs_data, rhs_data):
@@ -467,13 +467,13 @@ def binary_reduce_real(reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data,
     return out_data, grad

-def copy_reduce(reducer, graph, target, in_data, out_size, in_map,
-                out_map):
+def copy_reduce(reducer, graph, target, in_data, out_size, in_map=(None, None),
+                out_map=(None, None)):
     @tf.custom_gradient
-    def _labmda(in_data):
+    def _lambda(in_data):
         return copy_reduce_real(reducer, graph, target, in_data, out_size, in_map,
                                 out_map)
-    return _labmda(in_data)
+    return _lambda(in_data)

 def copy_reduce_real(reducer, graph, target, in_data, out_size, in_map,
...
@@ -46,9 +46,9 @@ def binary_op_reduce(reducer, op, G, A_target, B_target, A, B, out,
       node ID, destination node ID, or edge ID, according to ``A_target``
       and ``B_target`` which could take either
-      - "source" (0),
-      - "destination" (1), or
-      - "edge" (2).
+      - "source" (dgl.function.TargetCode.SRC),
+      - "destination" (dgl.function.TargetCode.DST), or
+      - "edge" (dgl.function.TargetCode.EDGE).
     * ``A`` and ``B`` are data tensors. If ``A_target`` is "edge", then
       ``A.shape[0]`` should equal the number of edges of ``G``. Otherwise
@@ -298,9 +298,9 @@ def copy_reduce(reducer, G, target,
     * ``select_target`` would return the source node ID, destination node,
       ID, or edge ID, according to ``target`` which could take either
-      - "source" (0),
-      - "destination" (1), or
-      - "edge" (2)
+      - "source" (dgl.function.TargetCode.SRC),
+      - "destination" (dgl.function.TargetCode.DST), or
+      - "edge" (dgl.function.TargetCode.EDGE).
     * ``X`` is a data tensor. If ``target`` is "edge", then ``X.shape[0]``
       should equal the number of edges of ``G``. Otherwise that should
...
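
The constants the updated docstrings refer to live in `dgl.function`; per the integer codes in the wording they replace, they correspond to 0, 1 and 2. A quick check, assuming a DGL install:

    from dgl.function import TargetCode

    # Symbolic names instead of the raw codes make call sites such as the rewritten
    # edge_softmax below self-describing.
    print(TargetCode.SRC, TargetCode.DST, TargetCode.EDGE)   # expected: 0 1 2 per the old wording
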
@@ -2,8 +2,10 @@
 # pylint: disable= no-member, arguments-differ
 import torch as th

-from ... import function as fn
+from ...function import TargetCode
 from ...base import ALL, is_all
+from ... import backend as F
+from ... import utils

 __all__ = ['edge_softmax']
@@ -44,15 +46,25 @@ class EdgeSoftmax(th.autograd.Function):
         # a local variable
         if not is_all(eids):
             g = g.edge_subgraph(eids.long())
-        ctx.backward_cache = g
-        g = g.local_var()
-        g.edata['s'] = score
-        g.update_all(fn.copy_e('s', 'm'), fn.max('m', 'smax'))
-        g.apply_edges(fn.e_sub_v('s', 'smax', 'out'))
-        g.edata['out'] = th.exp(g.edata['out'])
-        g.update_all(fn.copy_e('out', 'm'), fn.sum('m', 'out_sum'))
-        g.apply_edges(fn.e_div_v('out', 'out_sum', 'out'))
-        out = g.edata['out']
+        n_nodes = g.number_of_nodes()
+        n_edges = g.number_of_edges()
+        gidx = g._graph.get_immutable_gidx(utils.to_dgl_context(score.device))
+        ctx.backward_cache = n_nodes, n_edges, gidx
+        #g.update_all(fn.copy_e('s', 'm'), fn.max('m', 'smax'))
+        smax = F.copy_reduce('max', gidx, TargetCode.EDGE, score, n_nodes)
+        #g.apply_edges(fn.e_sub_v('s', 'smax', 'out'))
+        out = F.binary_reduce(
+            'none', 'sub', gidx, TargetCode.EDGE, TargetCode.DST, score, smax, n_edges)
+        #g.edata['out'] = th.exp(g.edata['out'])
+        out = th.exp(out)
+        #g.update_all(fn.copy_e('out', 'm'), fn.sum('m', 'out_sum'))
+        out_sum = F.copy_reduce('sum', gidx, TargetCode.EDGE, out, n_nodes)
+        #g.apply_edges(fn.e_div_v('out', 'out_sum', 'out'))
+        out = F.binary_reduce(
+            'none', 'div', gidx, TargetCode.EDGE, TargetCode.DST, out, out_sum, n_edges)
         ctx.save_for_backward(out)
         return out
@@ -72,16 +84,19 @@ class EdgeSoftmax(th.autograd.Function):
             grad_score = sds - sds * sds_sum  # multiple expressions
             return grad_score.data
         """
-        g = ctx.backward_cache
-        g = g.local_var()
+        n_nodes, n_edges, gidx = ctx.backward_cache
         out, = ctx.saved_tensors
-        # clear backward cache explicitly
-        ctx.backward_cache = None
-        g.edata['out'] = out
-        g.edata['grad_s'] = out * grad_out
-        g.update_all(fn.copy_e('grad_s', 'm'), fn.sum('m', 'accum'))
-        g.apply_edges(fn.e_mul_v('out', 'accum', 'out'))
-        grad_score = g.edata['grad_s'] - g.edata['out']
+        #g.edata['grad_s'] = out * grad_out
+        grad_s = out * grad_out
+        #g.update_all(fn.copy_e('grad_s', 'm'), fn.sum('m', 'accum'))
+        accum = F.copy_reduce('sum', gidx, TargetCode.EDGE, grad_s, n_nodes)
+        #g.apply_edges(fn.e_mul_v('out', 'accum', 'out'))
+        out = F.binary_reduce(
+            'none', 'mul', gidx, TargetCode.EDGE, TargetCode.DST, out, accum, n_edges)
+        #grad_score = g.edata['grad_s'] - g.edata['out']
+        grad_score = grad_s - out
         return None, grad_score, None
...
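
The rewritten forward/backward compute the same thing as the message-passing version they replace: a softmax over the incoming edges of each destination node, whose vector-Jacobian product is `grad_score = out * grad_out - out * accum`, where `accum` is the per-destination sum of `out * grad_out` (the `sds`/`sds_sum` recipe quoted in the docstring above). A small self-contained check in plain PyTorch, with no DGL graph and without the per-destination max subtraction the real kernel uses for numerical stability:

    import torch as th

    def edge_softmax_ref(score, dst, n_nodes):
        # Normalize edge scores within each group of edges sharing a destination node.
        e = th.exp(score)
        denom = th.zeros(n_nodes).scatter_add(0, dst, e)
        return e / denom[dst]

    score = th.randn(6, requires_grad=True)
    dst = th.tensor([0, 0, 1, 1, 1, 2])           # destination node of each edge
    out = edge_softmax_ref(score, dst, n_nodes=3)

    grad_out = th.randn(6)
    out.backward(grad_out)

    # Manual gradient, following the same steps as EdgeSoftmax.backward in this commit.
    with th.no_grad():
        grad_s = out * grad_out
        accum = th.zeros(3).scatter_add(0, dst, grad_s)
        manual = grad_s - out * accum[dst]

    print(th.allclose(score.grad, manual))        # expected: True
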
@@ -303,8 +303,8 @@ def test_all_binary_builtins():
     def _print_error(a, b):
         print("ERROR: Test {}_{}_{}_{} broadcast: {} partial: {}".
               format(lhs, binary_op, rhs, reducer, broadcast, partial))
-        print("lhs", F.asnumpy(lhs).tolist())
-        print("rhs", F.asnumpy(rhs).tolist())
+        print("lhs", lhs)
+        print("rhs", rhs)
         for i, (x, y) in enumerate(zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())):
             if not np.allclose(x, y, rtol, atol):
                 print('@{} {} v.s. {}'.format(i, x, y))
...