Unverified Commit 86fe58eb authored by Quan (Andy) Gan, committed by GitHub

[Bug] Another fix on PyTorch memory leakage issue (#1139)



* another fix

* another try

* fix

* rewriting with kernel functions

* revert mxnet softmax changes

* lint fix
Co-authored-by: Zihao Ye <zihaoye.cs@gmail.com>
parent a199b58a
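
The PyTorch changes below all follow one pattern: the output tensor is allocated in a thin Python wrapper and handed to `Function.forward`, every tensor that `backward` needs goes through `ctx.save_for_backward`, and only non-tensor metadata stays on `ctx.backward_cache`. Keeping tensors on a plain `ctx` attribute ties the output to its own `grad_fn` in a reference cycle that crosses the Python/C++ boundary, which is a common way for autograd memory to stick around. A minimal sketch of the pattern, with purely illustrative names (`SquareScaleInto`, `square_scale`) rather than DGL code:

```python
# Minimal sketch of the pattern this commit applies to BinaryReduce/CopyReduce.
# Toy example only; not DGL code.
import torch as th

class SquareScaleInto(th.autograd.Function):
    """Writes alpha * inp**2 into a pre-allocated output buffer."""

    @staticmethod
    def forward(ctx, inp, out, alpha):
        out.copy_(alpha * inp * inp)    # the "kernel" writes into the buffer it was given
        ctx.alpha = alpha               # non-tensor metadata may stay on ctx
        ctx.save_for_backward(inp)      # tensors needed by backward go here
        return out

    @staticmethod
    def backward(ctx, grad_out):
        inp, = ctx.saved_tensors
        # one gradient slot per forward input: (inp, out, alpha)
        return 2.0 * ctx.alpha * inp * grad_out, None, None

def square_scale(inp, alpha):
    # The output is allocated outside forward and passed in, mirroring the new
    # binary_reduce / copy_reduce wrappers in this diff.
    out = inp.new_empty(inp.shape)
    return SquareScaleInto.apply(inp, out, alpha)
```

Under this layout `torch.autograd.gradcheck(lambda t: square_scale(t, 1.5), (th.randn(4, dtype=th.float64, requires_grad=True),))` still passes: only the ownership of the buffers and saved tensors changes, not the math.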
@@ -445,7 +445,7 @@ class BinaryReduce(mx.autograd.Function):
 
 
 def binary_reduce(reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data,
-                  out_size, lhs_map, rhs_map, out_map):
+                  out_size, lhs_map=(None, None), rhs_map=(None, None), out_map=(None, None)):
     func = BinaryReduce(reducer, binary_op, graph, lhs, rhs, out_size, lhs_map,
                         rhs_map, out_map)
     return func(lhs_data, rhs_data)
@@ -508,7 +508,8 @@ class CopyReduce(mx.autograd.Function):
         return grad_in
 
 
-def copy_reduce(reducer, graph, target, in_data, out_size, in_map, out_map):
+def copy_reduce(reducer, graph, target, in_data, out_size, in_map=(None, None),
+                out_map=(None, None)):
     func = CopyReduce(reducer, graph, target, out_size, in_map, out_map)
     return func(in_data)
@@ -285,7 +285,7 @@ def zerocopy_from_dgl_ndarray(input):
 
 class BinaryReduce(th.autograd.Function):
     @staticmethod
-    def forward(ctx, reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data,
+    def forward(ctx, reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data, out_data,
                 out_size, lhs_map, rhs_map, out_map):
         lhs_data_nd = zerocopy_to_dgl_ndarray(lhs_data)
         rhs_data_nd = zerocopy_to_dgl_ndarray(rhs_data)
@@ -293,7 +293,6 @@ class BinaryReduce(th.autograd.Function):
         out_shape = feat_shape
         if binary_op == 'dot':
             out_shape = feat_shape[:-1]
-        out_data = lhs_data.new_empty((out_size,) + out_shape)
         out_data_nd = zerocopy_to_dgl_ndarray(out_data)
         K.binary_op_reduce(
             reducer if reducer != 'mean' else 'sum',
@@ -323,17 +322,17 @@ class BinaryReduce(th.autograd.Function):
             degs = None
         # save_for_backward can only save variables
         ctx.backward_cache = (reducer, binary_op, graph, lhs, rhs, lhs_map,
-                              rhs_map, out_map, lhs_data_nd, rhs_data_nd,
-                              feat_shape, degs)
-        ctx.save_for_backward(out_data)
+                              rhs_map, out_map, feat_shape, degs)
+        ctx.save_for_backward(lhs_data, rhs_data, out_data)
         return out_data
 
     @staticmethod
     def backward(ctx, grad_out):
         reducer, binary_op, graph, lhs, rhs, lhs_map, rhs_map, out_map, \
-            lhs_data_nd, rhs_data_nd, feat_shape, degs \
-            = ctx.backward_cache
-        out_data, = ctx.saved_tensors
+            feat_shape, degs = ctx.backward_cache
+        lhs_data, rhs_data, out_data = ctx.saved_tensors
+        lhs_data_nd = zerocopy_to_dgl_ndarray(lhs_data)
+        rhs_data_nd = zerocopy_to_dgl_ndarray(rhs_data)
         out_data_nd = zerocopy_to_dgl_ndarray(out_data)
         grad_lhs = None
         grad_rhs = None
@@ -357,19 +356,34 @@ class BinaryReduce(th.autograd.Function):
                 lhs_map[1], rhs_map[1], out_map[1])
             grad_rhs = _reduce_grad(grad_rhs, rhs_data_nd.shape)
-        return None, None, None, None, None, grad_lhs, grad_rhs, None, None, \
+        return None, None, None, None, None, grad_lhs, grad_rhs, None, None, None, \
             None, None
 
 
+def binary_reduce(reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data,
+                  out_size, lhs_map=(None, None), rhs_map=(None, None), out_map=(None, None)):
+    lhs_data_nd = zerocopy_to_dgl_ndarray(lhs_data)
+    rhs_data_nd = zerocopy_to_dgl_ndarray(rhs_data)
+    feat_shape = K.infer_binary_feature_shape(binary_op, lhs_data_nd, rhs_data_nd)
+    out_shape = feat_shape
+    if binary_op == 'dot':
+        out_shape = feat_shape[:-1]
+    out_data = lhs_data.new_empty((out_size,) + out_shape)
+    return BinaryReduce.apply(
+        reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data, out_data,
+        out_size, lhs_map, rhs_map, out_map)
+
+
 class CopyReduce(th.autograd.Function):
     @staticmethod
-    def forward(ctx, reducer, graph, target, in_data, out_size, in_map,
+    def forward(ctx, reducer, graph, target, in_data, out_data, out_size, in_map,
                 out_map):
-        out_data = in_data.new_empty((out_size,) + in_data.shape[1:])
         in_data_nd = zerocopy_to_dgl_ndarray(in_data)
         out_data_nd = zerocopy_to_dgl_ndarray(out_data)
         K.copy_reduce(
-            reducer if reducer != 'mean' else 'sum',
+            reducer if reducer != 'mean' else 'sum',
             graph, target, in_data_nd, out_data_nd, in_map[0], out_map[0])
         # normalize if mean reducer
         # NOTE(zihao): this is a temporary hack and we should have better solution in the future.
@@ -379,23 +393,22 @@ class CopyReduce(th.autograd.Function):
             in_ones_nd = zerocopy_to_dgl_ndarray(in_ones)
             degs_nd = zerocopy_to_dgl_ndarray(degs)
             K.copy_reduce(
-                'sum', graph, target, in_ones_nd, degs_nd, in_map[0], out_map[0])
+                'sum', graph, target, in_ones_nd, degs_nd, in_map[0], out_map[0])
             # reshape
             degs = degs.reshape((out_data.shape[0],) + (1,) * (out_data.dim() - 1)).clamp(min=1)
             out_data = out_data / degs
         else:
             degs = None
         # save_for_backward can only save variables
-        ctx.backward_cache = (reducer, graph, target, in_map, out_map,
-                              in_data_nd, degs)
-        ctx.save_for_backward(out_data)
+        ctx.backward_cache = (reducer, graph, target, in_map, out_map, degs)
+        ctx.save_for_backward(in_data, out_data)
         return out_data
 
     @staticmethod
     def backward(ctx, grad_out):
-        reducer, graph, target, in_map, out_map, in_data_nd, degs \
-            = ctx.backward_cache
-        out_data, = ctx.saved_tensors
+        reducer, graph, target, in_map, out_map, degs = ctx.backward_cache
+        in_data, out_data = ctx.saved_tensors
+        in_data_nd = zerocopy_to_dgl_ndarray(in_data)
         out_data_nd = zerocopy_to_dgl_ndarray(out_data)
         grad_in = None
         if reducer == 'mean':
@@ -404,14 +417,16 @@ class CopyReduce(th.autograd.Function):
         if ctx.needs_input_grad[3]:
             grad_in = grad_out.new_empty(in_data_nd.shape)
             K.backward_copy_reduce(
-                reducer if reducer != 'mean' else 'sum',
-                graph, target, in_data_nd, out_data_nd, grad_out_nd,
+                reducer if reducer != 'mean' else 'sum',
+                graph, target, in_data_nd, out_data_nd, grad_out_nd,
                 zerocopy_to_dgl_ndarray(grad_in), in_map[1], out_map[1])
-        return None, None, None, grad_in, None, None, None
+        return None, None, None, grad_in, None, None, None, None
 
 
-binary_reduce = BinaryReduce.apply
-copy_reduce = CopyReduce.apply
+def copy_reduce(reducer, graph, target, in_data, out_size, in_map=(None, None),
+                out_map=(None, None)):
+    out_data = in_data.new_empty((out_size,) + in_data.shape[1:])
+    return CopyReduce.apply(reducer, graph, target, in_data, out_data, out_size, in_map, out_map)
 
 
 def _reduce_grad(grad, shape):
@@ -385,7 +385,7 @@ def zerocopy_from_dgl_ndarray(input):
 def binary_reduce(reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data,
-                  out_size, lhs_map, rhs_map, out_map):
+                  out_size, lhs_map=(None, None), rhs_map=(None, None), out_map=(None, None)):
     @tf.custom_gradient
     def _lambda(lhs_data, rhs_data):
@@ -467,13 +467,13 @@ def binary_reduce_real(reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data,
     return out_data, grad
 
 
-def copy_reduce(reducer, graph, target, in_data, out_size, in_map,
-                out_map):
+def copy_reduce(reducer, graph, target, in_data, out_size, in_map=(None, None),
+                out_map=(None, None)):
     @tf.custom_gradient
-    def _labmda(in_data):
+    def _lambda(in_data):
         return copy_reduce_real(reducer, graph, target, in_data, out_size, in_map,
                                 out_map)
-    return _labmda(in_data)
+    return _lambda(in_data)
 
 
 def copy_reduce_real(reducer, graph, target, in_data, out_size, in_map,
@@ -46,9 +46,9 @@ def binary_op_reduce(reducer, op, G, A_target, B_target, A, B, out,
       node ID, destination node ID, or edge ID, according to ``A_target``
      and ``B_target`` which could take either
 
-      - "source" (0),
-      - "destination" (1), or
-      - "edge" (2).
+      - "source" (dgl.function.TargetCode.SRC),
+      - "destination" (dgl.function.TargetCode.DST), or
+      - "edge" (dgl.function.TargetCode.EDGE).
 
     * ``A`` and ``B`` are data tensors. If ``A_target`` is "edge", then
       ``A.shape[0]`` should equal the number of edges of ``G``. Otherwise
@@ -298,9 +298,9 @@ def copy_reduce(reducer, G, target,
     * ``select_target`` would return the source node ID, destination node,
       ID, or edge ID, according to ``target`` which could take either
 
-      - "source" (0),
-      - "destination" (1), or
-      - "edge" (2)
+      - "source" (dgl.function.TargetCode.SRC),
+      - "destination" (dgl.function.TargetCode.DST), or
+      - "edge" (dgl.function.TargetCode.EDGE).
 
     * ``X`` is a data tensor. If ``target`` is "edge", then ``X.shape[0]``
       should equal the number of edges of ``G``. Otherwise that should
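
For reference, the symbolic target codes that the updated docstrings point to are importable from `dgl.function` (the rewritten softmax module below imports them the same way), and per the old docstring text they correspond to the integers 0, 1 and 2. A tiny check, assuming a DGL install:

```python
# TargetCode is imported here exactly as the rewritten softmax module does below;
# the expected values follow the old docstring ("source" (0), "destination" (1), "edge" (2)).
from dgl.function import TargetCode

print(TargetCode.SRC, TargetCode.DST, TargetCode.EDGE)  # expected: 0 1 2
```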
@@ -2,8 +2,10 @@
 # pylint: disable= no-member, arguments-differ
 import torch as th
 
-from ... import function as fn
+from ...function import TargetCode
 from ...base import ALL, is_all
+from ... import backend as F
+from ... import utils
 
 __all__ = ['edge_softmax']
@@ -44,15 +46,25 @@ class EdgeSoftmax(th.autograd.Function):
         # a local variable
         if not is_all(eids):
             g = g.edge_subgraph(eids.long())
-        ctx.backward_cache = g
-        g = g.local_var()
-        g.edata['s'] = score
-        g.update_all(fn.copy_e('s', 'm'), fn.max('m', 'smax'))
-        g.apply_edges(fn.e_sub_v('s', 'smax', 'out'))
-        g.edata['out'] = th.exp(g.edata['out'])
-        g.update_all(fn.copy_e('out', 'm'), fn.sum('m', 'out_sum'))
-        g.apply_edges(fn.e_div_v('out', 'out_sum', 'out'))
-        out = g.edata['out']
+        n_nodes = g.number_of_nodes()
+        n_edges = g.number_of_edges()
+        gidx = g._graph.get_immutable_gidx(utils.to_dgl_context(score.device))
+        ctx.backward_cache = n_nodes, n_edges, gidx
+        #g.update_all(fn.copy_e('s', 'm'), fn.max('m', 'smax'))
+        smax = F.copy_reduce('max', gidx, TargetCode.EDGE, score, n_nodes)
+        #g.apply_edges(fn.e_sub_v('s', 'smax', 'out'))
+        out = F.binary_reduce(
+            'none', 'sub', gidx, TargetCode.EDGE, TargetCode.DST, score, smax, n_edges)
+        #g.edata['out'] = th.exp(g.edata['out'])
+        out = th.exp(out)
+        #g.update_all(fn.copy_e('out', 'm'), fn.sum('m', 'out_sum'))
+        out_sum = F.copy_reduce('sum', gidx, TargetCode.EDGE, out, n_nodes)
+        #g.apply_edges(fn.e_div_v('out', 'out_sum', 'out'))
+        out = F.binary_reduce(
+            'none', 'div', gidx, TargetCode.EDGE, TargetCode.DST, out, out_sum, n_edges)
         ctx.save_for_backward(out)
         return out
@@ -72,16 +84,19 @@ class EdgeSoftmax(th.autograd.Function):
             grad_score = sds - sds * sds_sum  # multiple expressions
             return grad_score.data
         """
-        g = ctx.backward_cache
-        g = g.local_var()
+        n_nodes, n_edges, gidx = ctx.backward_cache
         out, = ctx.saved_tensors
         # clear backward cache explicitly
         ctx.backward_cache = None
-        g.edata['out'] = out
-        g.edata['grad_s'] = out * grad_out
-        g.update_all(fn.copy_e('grad_s', 'm'), fn.sum('m', 'accum'))
-        g.apply_edges(fn.e_mul_v('out', 'accum', 'out'))
-        grad_score = g.edata['grad_s'] - g.edata['out']
+        #g.edata['grad_s'] = out * grad_out
+        grad_s = out * grad_out
+        #g.update_all(fn.copy_e('grad_s', 'm'), fn.sum('m', 'accum'))
+        accum = F.copy_reduce('sum', gidx, TargetCode.EDGE, grad_s, n_nodes)
+        #g.apply_edges(fn.e_mul_v('out', 'accum', 'out'))
+        out = F.binary_reduce(
+            'none', 'mul', gidx, TargetCode.EDGE, TargetCode.DST, out, accum, n_edges)
+        #grad_score = g.edata['grad_s'] - g.edata['out']
+        grad_score = grad_s - out
         return None, grad_score, None
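
The rewritten `EdgeSoftmax` still computes a softmax over the scores of all edges that share a destination node; it just does so through raw `F.copy_reduce` / `F.binary_reduce` calls on the immutable graph index, so that `ctx.backward_cache` holds only `(n_nodes, n_edges, gidx)` instead of a whole `DGLGraph`, and the backward pass rebuilds the softmax gradient from the saved `out` alone. A plain-PyTorch reference of the forward computation, for a graph described by an edge-to-destination array (illustrative names; `scatter_reduce` needs a reasonably recent PyTorch):

```python
# Reference edge softmax over destination-node groups, mirroring the kernel
# sequence in the new forward: max -> subtract -> exp -> sum -> divide.
# `score` holds one value per edge, `dst[e]` is the destination node of edge e.
import torch as th

def edge_softmax_reference(score, dst, n_nodes):
    # per-destination max, as F.copy_reduce('max', ...) above, for numerical stability
    smax = th.full((n_nodes,), float('-inf')).scatter_reduce(
        0, dst, score, reduce='amax', include_self=True)
    out = th.exp(score - smax[dst])                       # e_sub_v followed by exp
    out_sum = th.zeros(n_nodes).index_add(0, dst, out)    # F.copy_reduce('sum', ...)
    return out / out_sum[dst]                             # e_div_v

# four edges, all pointing at node 0 of a two-node graph
score = th.tensor([1.0, 2.0, 3.0, 4.0])
dst = th.tensor([0, 0, 0, 0])
print(edge_softmax_reference(score, dst, 2))  # equals th.softmax(score, 0)
```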
@@ -303,8 +303,8 @@ def test_all_binary_builtins():
     def _print_error(a, b):
         print("ERROR: Test {}_{}_{}_{} broadcast: {} partial: {}".
               format(lhs, binary_op, rhs, reducer, broadcast, partial))
-        print("lhs", F.asnumpy(lhs).tolist())
-        print("rhs", F.asnumpy(rhs).tolist())
+        print("lhs", lhs)
+        print("rhs", rhs)
         for i, (x, y) in enumerate(zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())):
             if not np.allclose(x, y, rtol, atol):
                 print('@{} {} v.s. {}'.format(i, x, y))