Unverified commit 76af2a2e, authored by Zihao Ye and committed by GitHub

[perf] Remove activation cache if not required. (#3258)

* upd

* fix

* upd
parent ac01e880
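The pattern behind this change: forward() decides, from the op, the reduce op and the operands' requires_grad flags, which tensors backward() will actually read, and passes everything else to ctx.save_for_backward as None; cheap metadata (shapes, dtype, device) travels through ctx.backward_cache instead. Below is a minimal sketch of that idea with a toy autograd function — it is not DGL code, and the names (MulSum, x, w) are purely illustrative.

import torch as th

class MulSum(th.autograd.Function):
    """Toy example: out = (x * w).sum(0). dx needs only w; dw needs only x."""

    @staticmethod
    def forward(ctx, x, w):
        out = (x * w).sum(0)
        ctx.x_shape = x.shape  # cheap metadata, analogous to backward_cache
        # Cache an operand only if the *other* operand requires a gradient.
        ctx.save_for_backward(x if w.requires_grad else None,
                              w if x.requires_grad else None)
        return out

    @staticmethod
    def backward(ctx, dout):
        x_saved, w_saved = ctx.saved_tensors
        dx = (dout * w_saved).expand(*ctx.x_shape) if ctx.needs_input_grad[0] else None
        dw = (dout * x_saved).sum(0, keepdim=True) if ctx.needs_input_grad[1] else None
        return dx, dw

# w needs no gradient here, so forward() stores (None, w): the activation x
# is freed after the forward pass instead of being cached for backward.
x = th.randn(4, 3, requires_grad=True)
w = th.randn(1, 3)
MulSum.apply(x, w).sum().backward()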
@@ -66,21 +66,54 @@ def _need_reduce_last_dim(ufeat, efeat):
"""Indicates whether to reduce the last dimension on edges
in the backward pass of spmm,
if so, use dot instead of mul."""
if ufeat is None or efeat is None:
return False
ushp = ufeat.shape
eshp = efeat.shape
return ushp[1:-1] == eshp[1:-1] and eshp[-1] == 1 and ushp[-1] > 1
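# For example, ufeat of shape (N, D) with efeat of shape (E, 1) broadcasts over
# the last dimension, so the edge gradient must reduce it (hence 'dot' instead
# of 'mul' in the backward pass); two (..., D) operands with D > 1 do not.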
def _muldiv(op, x):
return 1. / x if op == 'div' else x
def _expand(x, shape):
return x.expand(-1, *shape)
def _addsub(op, x):
return -x if op == 'sub' else x
def spmm_cache_X(binary_op, reduce_op, req_grad_X, req_grad_Y):
"""Rules to identify whether to cache X in SpMM forward stage."""
if binary_op != 'copy_lhs' and req_grad_Y:
if reduce_op == 'sum':
return True
else:
if binary_op == 'mul':
return True
return False
def spmm_cache_Y(binary_op, reduce_op, req_grad_X, req_grad_Y):
"""Rules to identify whether to cache Y in SpMM forward stage."""
if binary_op != 'copy_rhs' and req_grad_X:
if reduce_op == 'sum':
if binary_op in ['mul', 'add']:
return True
else:
if binary_op == 'mul':
return True
return False
def spmm_cache_argX(binary_op, reduce_op, req_grad_X, req_grad_Y):
"""Rules to identify whether to cache argX in SpMM forward stage."""
if req_grad_X:
if reduce_op in ['min', 'max']:
return True
return False
def spmm_cache_argY(binary_op, reduce_op, req_grad_X, req_grad_Y):
"""Rules to identify whether to cache argY in SpMM forward stage."""
if req_grad_Y:
if reduce_op in ['min', 'max']:
return True
return False
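# Example: for gspmm(g, 'mul', 'sum', X, Y) where only X requires grad, the
# rules keep Y (needed for dX) but drop X (dY is never computed), and argX/argY
# are only kept when the reduce op is 'min' or 'max'.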
class GSpMM(th.autograd.Function):
@@ -88,58 +121,69 @@ class GSpMM(th.autograd.Function):
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, gidx, op, reduce_op, X, Y):
out, (argX, argY) = _gspmm(gidx, op, reduce_op, X, Y)
reduce_last = _need_reduce_last_dim(X, Y)
X_shape = X.shape if X is not None else None
Y_shape = Y.shape if Y is not None else None
dtype = X.dtype if X is not None else Y.dtype
device = X.device if X is not None else Y.device
ctx.backward_cache = gidx, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last
req_grad_X = X.requires_grad if X is not None else False
req_grad_Y = Y.requires_grad if Y is not None else False
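# Keep only the tensors the backward rules above will actually read; the rest
# are saved as None so the activations are not kept alive until backward.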
if not spmm_cache_X(op, reduce_op, req_grad_X, req_grad_Y):
X = None
if not spmm_cache_Y(op, reduce_op, req_grad_X, req_grad_Y):
Y = None
if not spmm_cache_argX(op, reduce_op, req_grad_X, req_grad_Y):
argX = None
if not spmm_cache_argY(op, reduce_op, req_grad_X, req_grad_Y):
argY = None
ctx.save_for_backward(X, Y, argX, argY)
return out
@staticmethod
@custom_bwd
def backward(ctx, dZ):
gidx, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last = ctx.backward_cache
X, Y, argX, argY = ctx.saved_tensors
if op != 'copy_rhs' and ctx.needs_input_grad[3]:
g_rev = gidx.reverse()
if reduce_op == 'sum':
if op == 'mul':
dX = gspmm(g_rev, 'mul', 'sum', dZ, Y)
elif op == 'add':
dX = gspmm(g_rev, 'copy_lhs', 'sum', dZ, Y)
elif op == 'copy_lhs':
dX = gspmm(g_rev, 'copy_lhs', 'sum', dZ, None)
else: # max/min
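# For min/max reductions, gradients flow only through the winning entries:
# argY selects the edge values that produced each output, and argX tells which
# rows of dX they scatter back into.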
dX = th.zeros((X_shape[0],) + dZ.shape[1:],
dtype=dtype, device=device)
if op == 'mul':
grad = _expand(Y, dZ.shape[1:]).gather(
0, argY.long()) * dZ
dX.scatter_add_(0, argX.long(), grad)
elif op in ['add', 'copy_lhs']:
dX.scatter_add_(0, argX.long(), dZ)
dX = _reduce_grad(dX, X_shape)
else: # X has no gradient
dX = None
if op != 'copy_lhs' and ctx.needs_input_grad[4]:
if reduce_op == 'sum':
if op == 'mul' and reduce_last:
dY = gsddmm(gidx, 'dot', X, dZ)
elif op == 'mul':
dY = gsddmm(gidx, 'mul', X, dZ)
elif op in ['add', 'copy_rhs']:
dY = gsddmm(gidx, 'copy_rhs', X, dZ)
else: # max/min
dY = th.zeros((Y_shape[0],) + dZ.shape[1:],
dtype=dtype, device=device)
if op == 'mul':
grad = _expand(X, dZ.shape[1:]).gather(
0, argX.long()) * dZ
dY.scatter_add_(0, argY.long(), grad)
elif op in ['add', 'copy_rhs']:
dY.scatter_add_(0, argY.long(), dZ)
dY = _reduce_grad(dY, Y_shape)
else: # Y has no gradient
dY = None
return None, None, None, dX, dY
@@ -178,7 +222,7 @@ class GSpMM_hetero(th.autograd.Function):
# TODO(Israt): implement other combinations of message and reduce functions
if reduce_op == 'sum':
if op in ['copy_rhs']:
tmp_Z = tuple([dZ[i] if dZ[i] is not None else None
for i in range(len(dZ))])
tmp = tuple(X + tmp_Z)
dY = gsddmm_hetero(g, 'copy_rhs', 'u', 'v', *tmp)
@@ -188,62 +232,81 @@ dY = tuple([None] * len(Y))
dY = tuple([None] * len(Y))
return (None, None, None) + dX + dY
def sddmm_cache_X(op, req_grad_X, req_grad_Y):
"""Rules to identify whether to cache X in SDDMM forward stage."""
if op in ['mul', 'dot'] and req_grad_Y:
return True
return False
def sddmm_cache_Y(op, req_grad_X, req_grad_Y):
"""Rules to identify whether to cache Y in SDDMM forward stage."""
if op in ['mul', 'dot'] and req_grad_X:
return True
return False
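# Example: for 'dot' (and 'mul'), each operand appears in the other's gradient,
# so X is cached only if Y requires grad and vice versa; 'add' and the copy ops
# need neither operand in backward, so nothing is cached.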
class GSDDMM(th.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, gidx, op, X, Y, lhs_target, rhs_target):
out = _gsddmm(gidx, op, X, Y, lhs_target, rhs_target)
X_shape = X.shape if X is not None else None
Y_shape = Y.shape if Y is not None else None
ctx.backward_cache = gidx, op, lhs_target, rhs_target, X_shape, Y_shape
req_grad_X = X.requires_grad if X is not None else False
req_grad_Y = Y.requires_grad if Y is not None else False
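# As in GSpMM.forward, operands the backward rules do not need are replaced
# with None before saving.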
if not sddmm_cache_X(op, req_grad_X, req_grad_Y):
X = None
if not sddmm_cache_Y(op, req_grad_X, req_grad_Y):
Y = None
ctx.save_for_backward(X, Y)
return out
@staticmethod
@custom_bwd
def backward(ctx, dZ):
gidx, op, lhs_target, rhs_target, X_shape, Y_shape = ctx.backward_cache
X, Y = ctx.saved_tensors
if op != 'copy_rhs' and ctx.needs_input_grad[2]:
if lhs_target in ['u', 'v']:
_gidx = gidx if lhs_target == 'v' else gidx.reverse()
if op in ['add', 'copy_lhs']:
dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)
else: # mul, dot
if rhs_target == lhs_target:
dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ) * Y
elif rhs_target == 'e':
dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * Y)
else: # rhs_target = !lhs_target
dX = gspmm(_gidx, 'mul', 'sum', Y, dZ)
else: # lhs_target == 'e'
if op in ['add', 'copy_lhs']:
dX = dZ
else: # mul, dot
dX = gsddmm(gidx, 'mul', dZ, Y, 'e', rhs_target)
dX = _reduce_grad(dX, X_shape)
else:
dX = None
if op != 'copy_lhs' and ctx.needs_input_grad[3]:
if rhs_target in ['u', 'v']:
_gidx = gidx if rhs_target == 'v' else gidx.reverse()
if op in ['add', 'copy_rhs']:
dY = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)
else: # mul, dot
if lhs_target == rhs_target:
dY = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ) * X
elif lhs_target == 'e':
dY = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * X)
else: # rhs_target = !lhs_target
dY = gspmm(_gidx, 'mul', 'sum', X, dZ)
else:
if op in ['add', 'copy_rhs']:
dY = dZ
else: # mul, dot
dY = gsddmm(gidx, 'mul', dZ, X, 'e', lhs_target)
dY = _reduce_grad(dY, Y_shape)
else:
dY = None
return None, None, dX, dY, None, None
@@ -422,9 +485,21 @@ class CSRMask(th.autograd.Function):
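# The wrappers below rewrite 'sub' as 'add' with a negated rhs and 'div' as
# 'mul' with a reciprocal rhs, so GSpMM/GSDDMM no longer need dedicated
# 'sub'/'div' branches in their backward passes.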
def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
if op == 'sub':
op = 'add'
rhs_data = -rhs_data
if op == 'div':
op = 'mul'
rhs_data = 1. / rhs_data
return GSpMM.apply(gidx, op, reduce_op, lhs_data, rhs_data)
def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
if op == 'sub':
op = 'add'
rhs_data = -rhs_data
if op == 'div':
op = 'mul'
rhs_data = 1. / rhs_data
return GSDDMM.apply(gidx, op, lhs_data, rhs_data, lhs_target, rhs_target)
def gspmm_hetero(g, op, reduce_op, *lhs_and_rhs_tuple):
......