"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "a8523bffa844752f8080e2ee675f91c32e392cf0"
Unverified Commit 76af2a2e authored by Zihao Ye, committed by GitHub

[perf] Remove activation cache if not required. (#3258)

* upd

* fix

* upd
parent ac01e880
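In short: the SpMM/SDDMM autograd Functions below now decide, per message op and reduce op, which of X, Y, argX, argY their backward pass can actually use, and drop the rest before ctx.save_for_backward, so unneeded activations are freed instead of being kept alive by the graph. As a point of reference, here is a minimal standalone sketch of that conditional-saving pattern (plain PyTorch, not DGL code; the MulOrCopy name and toy op are invented for illustration):

import torch as th

class MulOrCopy(th.autograd.Function):
    """out = x * y when op == 'mul'; a plain copy of x when op == 'copy_lhs'."""

    @staticmethod
    def forward(ctx, op, x, y):
        out = x * y if op == 'mul' else x.clone()
        ctx.op = op
        # Only the 'mul' backward reads the operands; for 'copy_lhs' the
        # gradient of x is just dZ, so nothing has to stay alive.
        if op != 'mul':
            x, y = None, None
        ctx.save_for_backward(x, y)
        return out

    @staticmethod
    def backward(ctx, dZ):
        x, y = ctx.saved_tensors
        if ctx.op == 'mul':
            return None, dZ * y, dZ * x
        return None, dZ, None

x = th.randn(4, 3, requires_grad=True)
y = th.randn(4, 3, requires_grad=True)
MulOrCopy.apply('copy_lhs', x, y).sum().backward()  # y is never cached here

In the diff itself, the same decision is made by the spmm_cache_* and sddmm_cache_* helpers added below.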
@@ -66,21 +66,54 @@ def _need_reduce_last_dim(ufeat, efeat):
     """Indicates whether to reduce the last dimension on edges
     in the backward pass of spmm,
     if so, use dot instead of mul."""
+    if ufeat is None or efeat is None:
+        return False
     ushp = ufeat.shape
     eshp = efeat.shape
     return ushp[1:-1] == eshp[1:-1] and eshp[-1] == 1 and ushp[-1] > 1


-def _muldiv(op, x):
-    return 1. / x if op == 'div' else x
-
-
-def _addsub(op, x):
-    return -x if op == 'sub' else x
-
-
 def _expand(x, shape):
     return x.expand(-1, *shape)


+def spmm_cache_X(binary_op, reduce_op, req_grad_X, req_grad_Y):
+    """Rules to identify whether to cache X in SpMM forward stage."""
+    if binary_op != 'copy_lhs' and req_grad_Y:
+        if reduce_op == 'sum':
+            return True
+        else:
+            if binary_op == 'mul':
+                return True
+    return False
+
+
+def spmm_cache_Y(binary_op, reduce_op, req_grad_X, req_grad_Y):
+    """Rules to identify whether to cache Y in SpMM forward stage."""
+    if binary_op != 'copy_rhs' and req_grad_X:
+        if reduce_op == 'sum':
+            if binary_op in ['mul', 'add']:
+                return True
+        else:
+            if binary_op == 'mul':
+                return True
+    return False
+
+
+def spmm_cache_argX(binary_op, reduce_op, req_grad_X, req_grad_Y):
+    """Rules to identify whether to cache argX in SpMM forward stage."""
+    if req_grad_X:
+        if reduce_op in ['min', 'max']:
+            return True
+    return False
+
+
+def spmm_cache_argY(binary_op, reduce_op, req_grad_X, req_grad_Y):
+    """Rules to identify whether to cache argY in SpMM forward stage."""
+    if req_grad_Y:
+        if reduce_op in ['min', 'max']:
+            return True
+    return False
+
+
 class GSpMM(th.autograd.Function):
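For concreteness, the rules just added can be spot-checked directly. A small illustrative snippet (it assumes both operands require grad and that the helpers are importable from dgl.backend.pytorch.sparse, which appears to be the module this diff touches):

from dgl.backend.pytorch.sparse import (
    spmm_cache_X, spmm_cache_Y, spmm_cache_argX, spmm_cache_argY)

# mul/sum: each operand is needed to form the other operand's gradient.
assert spmm_cache_X('mul', 'sum', True, True)
assert spmm_cache_Y('mul', 'sum', True, True)
# add/max: the gradient is a pure scatter of dZ, so only the arg indices are kept.
assert not spmm_cache_Y('add', 'max', True, True)
assert spmm_cache_argX('add', 'max', True, True)
# copy_lhs/sum: dX is just dZ pushed through the reverse graph; cache nothing.
assert not spmm_cache_X('copy_lhs', 'sum', True, True)
assert not spmm_cache_Y('copy_lhs', 'sum', True, True)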
@@ -88,58 +121,69 @@ class GSpMM(th.autograd.Function):
     @custom_fwd(cast_inputs=th.float16)
     def forward(ctx, gidx, op, reduce_op, X, Y):
         out, (argX, argY) = _gspmm(gidx, op, reduce_op, X, Y)
-        ctx.backward_cache = gidx, op, reduce_op
+        reduce_last = _need_reduce_last_dim(X, Y)
+        X_shape = X.shape if X is not None else None
+        Y_shape = Y.shape if Y is not None else None
+        dtype = X.dtype if X is not None else Y.dtype
+        device = X.device if X is not None else Y.device
+        ctx.backward_cache = gidx, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last
+        req_grad_X = X.requires_grad if X is not None else False
+        req_grad_Y = Y.requires_grad if Y is not None else False
+        if not spmm_cache_X(op, reduce_op, req_grad_X, req_grad_Y):
+            X = None
+        if not spmm_cache_Y(op, reduce_op, req_grad_X, req_grad_Y):
+            Y = None
+        if not spmm_cache_argX(op, reduce_op, req_grad_X, req_grad_Y):
+            argX = None
+        if not spmm_cache_argY(op, reduce_op, req_grad_X, req_grad_Y):
+            argY = None
         ctx.save_for_backward(X, Y, argX, argY)
         return out

     @staticmethod
     @custom_bwd
     def backward(ctx, dZ):
-        gidx, op, reduce_op = ctx.backward_cache
+        gidx, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last = ctx.backward_cache
         X, Y, argX, argY = ctx.saved_tensors
         if op != 'copy_rhs' and ctx.needs_input_grad[3]:
             g_rev = gidx.reverse()
             if reduce_op == 'sum':
-                if op in ['mul', 'div']:
-                    dX = gspmm(g_rev, 'mul', 'sum', dZ, _muldiv(op, Y))
-                elif op in ['add', 'sub']:
+                if op == 'mul':
+                    dX = gspmm(g_rev, 'mul', 'sum', dZ, Y)
+                elif op == 'add':
                     dX = gspmm(g_rev, 'copy_lhs', 'sum', dZ, Y)
                 elif op == 'copy_lhs':
                     dX = gspmm(g_rev, 'copy_lhs', 'sum', dZ, None)
             else:  # max/min
-                dX = th.zeros((X.shape[0],) + dZ.shape[1:],
-                              dtype=X.dtype, device=X.device)
-                if op in ['mul', 'div']:
-                    grad = _muldiv(op, _expand(Y, dZ.shape[1:]).gather(
-                        0, argY.long())) * dZ
+                dX = th.zeros((X_shape[0],) + dZ.shape[1:],
+                              dtype=dtype, device=device)
+                if op == 'mul':
+                    grad = _expand(Y, dZ.shape[1:]).gather(
+                        0, argY.long()) * dZ
                     dX.scatter_add_(0, argX.long(), grad)
-                elif op in ['add', 'sub', 'copy_lhs']:
+                elif op in ['add', 'copy_lhs']:
                     dX.scatter_add_(0, argX.long(), dZ)
-            dX = _reduce_grad(dX, X.shape)
+            dX = _reduce_grad(dX, X_shape)
         else:  # X has not gradient
             dX = None
         if op != 'copy_lhs' and ctx.needs_input_grad[4]:
             if reduce_op == 'sum':
-                if op == 'mul' and _need_reduce_last_dim(X, Y):
+                if op == 'mul' and reduce_last:
                     dY = gsddmm(gidx, 'dot', X, dZ)
-                elif op in ['mul', 'div']:
+                elif op == 'mul':
                     dY = gsddmm(gidx, 'mul', X, dZ)
-                    if op == 'div':
-                        dY = -dY / (Y ** 2)
-                elif op in ['add', 'sub', 'copy_rhs']:
-                    dY = gsddmm(gidx, 'copy_rhs', X, _addsub(op, dZ))
+                elif op in ['add', 'copy_rhs']:
+                    dY = gsddmm(gidx, 'copy_rhs', X, dZ)
             else:  # max/min
-                dY = th.zeros((Y.shape[0],) + dZ.shape[1:],
-                              dtype=Y.dtype, device=Y.device)
-                if op in ['mul', 'div']:
+                dY = th.zeros((Y_shape[0],) + dZ.shape[1:],
+                              dtype=dtype, device=device)
+                if op == 'mul':
                     grad = _expand(X, dZ.shape[1:]).gather(
                         0, argX.long()) * dZ
                     dY.scatter_add_(0, argY.long(), grad)
-                    if op == 'div':
-                        dY = -dY / (Y ** 2)
-                elif op in ['add', 'sub', 'copy_rhs']:
-                    dY.scatter_add_(0, argY.long(), _addsub(op, dZ))
-            dY = _reduce_grad(dY, Y.shape)
+                elif op in ['add', 'copy_rhs']:
+                    dY.scatter_add_(0, argY.long(), dZ)
+            dY = _reduce_grad(dY, Y_shape)
         else:  # Y has no gradient
             dY = None
         return None, None, None, dX, dY
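The max/min branches above are also why only argX/argY need to survive the forward pass: the backward of a max (or min) reduction just routes dZ to the winning rows, which is exactly a scatter_add_ over the stored arg indices. A small standalone sketch of that pattern (plain PyTorch with made-up toy data, not DGL code):

import torch as th

src = th.randn(6, 3)                    # 6 messages, feature size 3
seg = th.tensor([0, 0, 1, 1, 1, 2])     # destination of each message
out = th.full((3, 3), float('-inf'))    # per-destination max
arg = th.zeros(3, 3, dtype=th.long)     # which message won each slot
for i in range(6):                      # naive segment-max with argmax
    win = src[i] > out[seg[i]]
    out[seg[i]] = th.where(win, src[i], out[seg[i]])
    arg[seg[i]] = th.where(win, th.tensor(i), arg[seg[i]])

dZ = th.ones(3, 3)                      # upstream gradient w.r.t. out
dX = th.zeros(6, 3)
dX.scatter_add_(0, arg, dZ)             # same move as the dX.scatter_add_ above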
@@ -178,7 +222,7 @@ class GSpMM_hetero(th.autograd.Function):
             # TODO(Israt): implement other combinations of message and reduce functions
             if reduce_op == 'sum':
                 if op in ['copy_rhs']:
-                    tmp_Z = tuple([_addsub(op, dZ[i]) if dZ[i] is not None else None
+                    tmp_Z = tuple([dZ[i] if dZ[i] is not None else None
                                    for i in range(len(dZ))])
                     tmp = tuple(X + tmp_Z)
                     dY = gsddmm_hetero(g, 'copy_rhs', 'u', 'v', *tmp)
@@ -188,62 +232,81 @@ class GSpMM_hetero(th.autograd.Function):
             dY = tuple([None] * len(Y))
         return (None, None, None) + dX + dY

+
+def sddmm_cache_X(op, req_grad_X, req_grad_Y):
+    """Rules to identify whether to cache X in SDDMM forward stage."""
+    if op in ['mul', 'dot'] and req_grad_Y:
+        return True
+    return False
+
+
+def sddmm_cache_Y(op, req_grad_X, req_grad_Y):
+    """Rules to identify whether to cache Y in SDDMM forward stage."""
+    if op in ['mul', 'dot'] and req_grad_X:
+        return True
+    return False
+
+
 class GSDDMM(th.autograd.Function):
     @staticmethod
     @custom_fwd(cast_inputs=th.float16)
     def forward(ctx, gidx, op, X, Y, lhs_target, rhs_target):
         out = _gsddmm(gidx, op, X, Y, lhs_target, rhs_target)
-        ctx.backward_cache = gidx, op, lhs_target, rhs_target
+        X_shape = X.shape if X is not None else None
+        Y_shape = Y.shape if Y is not None else None
+        ctx.backward_cache = gidx, op, lhs_target, rhs_target, X_shape, Y_shape
+        req_grad_X = X.requires_grad if X is not None else False
+        req_grad_Y = Y.requires_grad if Y is not None else False
+        if not sddmm_cache_X(op, req_grad_X, req_grad_Y):
+            X = None
+        if not sddmm_cache_Y(op, req_grad_X, req_grad_Y):
+            Y = None
         ctx.save_for_backward(X, Y)
         return out

     @staticmethod
     @custom_bwd
     def backward(ctx, dZ):
-        gidx, op, lhs_target, rhs_target = ctx.backward_cache
+        gidx, op, lhs_target, rhs_target, X_shape, Y_shape = ctx.backward_cache
         X, Y = ctx.saved_tensors
         if op != 'copy_rhs' and ctx.needs_input_grad[2]:
             if lhs_target in ['u', 'v']:
                 _gidx = gidx if lhs_target == 'v' else gidx.reverse()
-                if op in ['add', 'sub', 'copy_lhs']:
+                if op in ['add', 'copy_lhs']:
                     dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)
-                else:  # mul, div, dot
+                else:  # mul, dot
                     if rhs_target == lhs_target:
-                        dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ) * _muldiv(op, Y)
+                        dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ) * Y
                     elif rhs_target == 'e':
-                        dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * _muldiv(op, Y))
+                        dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * Y)
                     else:  # rhs_target = !lhs_target
-                        dX = gspmm(_gidx, 'mul', 'sum', _muldiv(op, Y), dZ)
+                        dX = gspmm(_gidx, 'mul', 'sum', Y, dZ)
             else:  # lhs_target == 'e'
-                if op in ['add', 'sub', 'copy_lhs']:
+                if op in ['add', 'copy_lhs']:
                     dX = dZ
-                else:  # mul, div, dot
-                    dX = gsddmm(gidx, 'mul', dZ, _muldiv(op, Y), 'e', rhs_target)
-            dX = _reduce_grad(dX, X.shape)
+                else:  # mul, dot
+                    dX = gsddmm(gidx, 'mul', dZ, Y, 'e', rhs_target)
+            dX = _reduce_grad(dX, X_shape)
         else:
             dX = None
         if op != 'copy_lhs' and ctx.needs_input_grad[3]:
             if rhs_target in ['u', 'v']:
                 _gidx = gidx if rhs_target == 'v' else gidx.reverse()
-                if op in ['add', 'sub', 'copy_rhs']:
-                    dY = gspmm(_gidx, 'copy_rhs', 'sum', None, _addsub(op, dZ))
-                else:  # mul, div, dot
+                if op in ['add', 'copy_rhs']:
+                    dY = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)
+                else:  # mul, dot
                     if lhs_target == rhs_target:
                         dY = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ) * X
                     elif lhs_target == 'e':
                         dY = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * X)
                     else:  # rhs_target = !lhs_target
                         dY = gspmm(_gidx, 'mul', 'sum', X, dZ)
-                    if op == 'div':
-                        dY = -dY / (Y ** 2)
             else:
-                if op in ['add', 'sub', 'copy_rhs']:
-                    dY = _addsub(op, dZ)
-                else:  # mul, div, dot
+                if op in ['add', 'copy_rhs']:
+                    dY = dZ
+                else:  # mul, dot
                     dY = gsddmm(gidx, 'mul', dZ, X, 'e', lhs_target)
-                    if op == 'div':
-                        dY = -dY / (Y ** 2)
-            dY = _reduce_grad(dY, Y.shape)
+            dY = _reduce_grad(dY, Y_shape)
         else:
             dY = None
         return None, None, dX, dY, None, None
@@ -422,9 +485,21 @@ class CSRMask(th.autograd.Function):
 def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
+    if op == 'sub':
+        op = 'add'
+        rhs_data = -rhs_data
+    if op == 'div':
+        op = 'mul'
+        rhs_data = 1. / rhs_data
     return GSpMM.apply(gidx, op, reduce_op, lhs_data, rhs_data)


 def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
+    if op == 'sub':
+        op = 'add'
+        rhs_data = -rhs_data
+    if op == 'div':
+        op = 'mul'
+        rhs_data = 1. / rhs_data
     return GSDDMM.apply(gidx, op, lhs_data, rhs_data, lhs_target, rhs_target)


 def gspmm_hetero(g, op, reduce_op, *lhs_and_rhs_tuple):
...
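Finally, the gspmm/gsddmm wrappers above fold 'sub' and 'div' into 'add' and 'mul' by negating or reciprocating rhs_data before the autograd Function runs; that is why the div-specific -dY / (Y ** 2) corrections and the _muldiv/_addsub helpers could be deleted: autograd now differentiates the 1. / rhs_data step itself. A quick standalone check of that equivalence (plain PyTorch):

import torch as th

x = th.randn(5, requires_grad=True)
y = th.rand(5, requires_grad=True) + 0.5     # keep the divisor away from zero

(x / y).sum().backward()                     # direct division
gx, gy = x.grad.clone(), y.grad.clone()

x.grad, y.grad = None, None
(x * (1. / y)).sum().backward()              # 'div' rewritten as 'mul'
assert th.allclose(x.grad, gx)
assert th.allclose(y.grad, gy)               # autograd supplies the -x / y**2 term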