Commit c801a164 (unverified), authored by Zihao Ye, committed by GitHub
Browse files

[hotfix] Allow broadcastable expand in the backward phase of gspmm (#1939)



* upd

* upd

* upd

* upd

* upd

* trigger

* simplify unittest

* patch-sp
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent 929c99ed
...@@ -103,6 +103,13 @@ def _addsub(op, x): ...@@ -103,6 +103,13 @@ def _addsub(op, x):
return -x if op == 'sub' else x return -x if op == 'sub' else x
def _expand(x, shape):
    """Broadcast ``x`` to ``(x.shape[0], *shape)``, keeping its leading axis.

    When ``x`` has fewer axes than the target, size-1 axes are inserted
    directly after the leading axis so the remaining axes of ``x`` line up
    with the trailing axes of ``shape`` before broadcasting.

    Parameters
    ----------
    x : array-like
        Object exposing ``ndim``, ``shape``, ``reshape`` and ``broadcast_to``.
    shape : tuple of int
        Target shape for every axis after the leading one.

    Returns
    -------
    The result of ``x.broadcast_to((x.shape[0], *shape))``.
    """
    # Number of singleton axes to insert; the target rank is len(shape) + 1
    # because the leading axis of ``x`` is kept as-is.
    num_singleton_axes = len(shape) + 1 - x.ndim
    if num_singleton_axes > 0:
        padded_shape = (x.shape[0],) + (1,) * num_singleton_axes + x.shape[1:]
        x = x.reshape(padded_shape)
    return x.broadcast_to((x.shape[0], *shape))
class GSpMM(mx.autograd.Function): class GSpMM(mx.autograd.Function):
def __init__(self, gidx, op, reduce_op): def __init__(self, gidx, op, reduce_op):
super(GSpMM, self).__init__() super(GSpMM, self).__init__()
...@@ -132,7 +139,7 @@ class GSpMM(mx.autograd.Function): ...@@ -132,7 +139,7 @@ class GSpMM(mx.autograd.Function):
if op in ['mul', 'div']: if op in ['mul', 'div']:
dX = _scatter_nd( dX = _scatter_nd(
argX, argX,
_muldiv(op, _gather_nd(argY, Y.broadcast_to((Y.shape[0], *dZ.shape[1:])))) * dZ, _muldiv(op, _gather_nd(argY, _expand(Y, dZ.shape[1:]))) * dZ,
X.shape[0]) X.shape[0])
elif op in ['add', 'sub', 'copy_lhs']: elif op in ['add', 'sub', 'copy_lhs']:
dX = _scatter_nd(argX, dZ, X.shape[0]) dX = _scatter_nd(argX, dZ, X.shape[0])
...@@ -153,7 +160,7 @@ class GSpMM(mx.autograd.Function): ...@@ -153,7 +160,7 @@ class GSpMM(mx.autograd.Function):
if op in ['mul', 'div']: if op in ['mul', 'div']:
dY = _scatter_nd( dY = _scatter_nd(
argY, argY,
_gather_nd(argX, X.broadcast_to((X.shape[0], *dZ.shape[1:]))) * dZ, _gather_nd(argX, _expand(X, dZ.shape[1:])) * dZ,
Y.shape[0]) Y.shape[0])
if op == 'div': if op == 'div':
dY = -dY / (Y ** 2) dY = -dY / (Y ** 2)
......
...@@ -53,6 +53,13 @@ def _addsub(op, x): ...@@ -53,6 +53,13 @@ def _addsub(op, x):
return -x if op == 'sub' else x return -x if op == 'sub' else x
def _expand(x, shape):
    """Expand ``x`` to ``(x.shape[0], *shape)``, keeping its leading axis.

    When ``x`` has fewer axes than the target, size-1 axes are inserted
    directly after the leading axis so the remaining axes of ``x`` line up
    with the trailing axes of ``shape`` before expanding.

    Parameters
    ----------
    x : tensor
        Object exposing ``ndim``, ``shape``, ``view`` and ``expand``.
    shape : tuple of int
        Target shape for every axis after the leading one.

    Returns
    -------
    The result of ``x.expand(-1, *shape)``; ``-1`` leaves the leading axis
    at its current size.
    """
    # Number of singleton axes to insert; the target rank is len(shape) + 1
    # because the leading axis of ``x`` is kept as-is.
    num_singleton_axes = len(shape) + 1 - x.ndim
    if num_singleton_axes > 0:
        padded_shape = (x.shape[0],) + (1,) * num_singleton_axes + x.shape[1:]
        x = x.view(padded_shape)
    return x.expand(-1, *shape)
class GSpMM(th.autograd.Function): class GSpMM(th.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, gidx, op, reduce_op, X, Y): def forward(ctx, gidx, op, reduce_op, X, Y):
...@@ -78,7 +85,7 @@ class GSpMM(th.autograd.Function): ...@@ -78,7 +85,7 @@ class GSpMM(th.autograd.Function):
dX = th.zeros((X.shape[0],) + dZ.shape[1:], dX = th.zeros((X.shape[0],) + dZ.shape[1:],
dtype=X.dtype, device=X.device) dtype=X.dtype, device=X.device)
if op in ['mul', 'div']: if op in ['mul', 'div']:
grad = _muldiv(op, Y.expand(-1, *dZ.shape[1:]).gather( grad = _muldiv(op, _expand(Y, dZ.shape[1:]).gather(
0, argY.long())) * dZ 0, argY.long())) * dZ
dX.scatter_add_(0, argX.long(), grad) dX.scatter_add_(0, argX.long(), grad)
elif op in ['add', 'sub', 'copy_lhs']: elif op in ['add', 'sub', 'copy_lhs']:
...@@ -99,8 +106,9 @@ class GSpMM(th.autograd.Function): ...@@ -99,8 +106,9 @@ class GSpMM(th.autograd.Function):
else: # max/min else: # max/min
dY = th.zeros((Y.shape[0],) + dZ.shape[1:], dY = th.zeros((Y.shape[0],) + dZ.shape[1:],
dtype=Y.dtype, device=Y.device) dtype=Y.dtype, device=Y.device)
print(X.shape, dZ.shape)
if op in ['mul', 'div']: if op in ['mul', 'div']:
grad = X.expand(-1, *dZ.shape[1:]).gather( grad = _expand(X, dZ.shape[1:]).gather(
0, argX.long()) * dZ 0, argX.long()) * dZ
dY.scatter_add_(0, argY.long(), grad) dY.scatter_add_(0, argY.long(), grad)
if op == 'div': if op == 'div':
...@@ -108,6 +116,7 @@ class GSpMM(th.autograd.Function): ...@@ -108,6 +116,7 @@ class GSpMM(th.autograd.Function):
elif op in ['add', 'sub', 'copy_rhs']: elif op in ['add', 'sub', 'copy_rhs']:
dY.scatter_add_(0, argY.long(), _addsub(op, dZ)) dY.scatter_add_(0, argY.long(), _addsub(op, dZ))
dY = _reduce_grad(dY, Y.shape) dY = _reduce_grad(dY, Y.shape)
print('jesus2')
else: # Y has no gradient else: # Y has no gradient
dY = None dY = None
return None, None, None, dX, dY return None, None, None, dX, dY
......
...@@ -39,7 +39,6 @@ def _gather_nd(index, src): ...@@ -39,7 +39,6 @@ def _gather_nd(index, src):
new_idx = index * stride + copy_to(sum(offsets), ctx) new_idx = index * stride + copy_to(sum(offsets), ctx)
src = tf.reshape(src, (-1,)) src = tf.reshape(src, (-1,))
new_idx = tf.reshape(new_idx, (-1)) new_idx = tf.reshape(new_idx, (-1))
print(src, new_idx)
rst = tf.reshape(tf.gather(src, new_idx), shp) rst = tf.reshape(tf.gather(src, new_idx), shp)
return rst return rst
...@@ -92,6 +91,13 @@ def _addsub(op, x): ...@@ -92,6 +91,13 @@ def _addsub(op, x):
return -x if op == 'sub' else x return -x if op == 'sub' else x
def _expand(x, shape):
    """Broadcast ``x`` to ``(x.shape[0], *shape)``, keeping its leading axis.

    When ``x`` has fewer axes than the target, size-1 axes are inserted
    directly after the leading axis (via ``tf.reshape``) so the remaining
    axes of ``x`` line up with the trailing axes of ``shape`` before
    broadcasting.

    Parameters
    ----------
    x : tensor
        Object exposing ``ndim`` and ``shape`` and accepted by ``tf.reshape``
        and ``tf.broadcast_to``.
    shape : tuple of int
        Target shape for every axis after the leading one.

    Returns
    -------
    The result of ``tf.broadcast_to(x, (x.shape[0], *shape))``.
    """
    # Number of singleton axes to insert; the target rank is len(shape) + 1
    # because the leading axis of ``x`` is kept as-is.
    num_singleton_axes = len(shape) + 1 - x.ndim
    if num_singleton_axes > 0:
        padded_shape = (x.shape[0],) + (1,) * num_singleton_axes + x.shape[1:]
        x = tf.reshape(x, padded_shape)
    return tf.broadcast_to(x, (x.shape[0], *shape))
def gspmm_real(gidx, op, reduce_op, X, Y): def gspmm_real(gidx, op, reduce_op, X, Y):
out, (argX, argY) = _gspmm(gidx, op, reduce_op, X, Y) out, (argX, argY) = _gspmm(gidx, op, reduce_op, X, Y)
...@@ -110,7 +116,7 @@ def gspmm_real(gidx, op, reduce_op, X, Y): ...@@ -110,7 +116,7 @@ def gspmm_real(gidx, op, reduce_op, X, Y):
if op in ['mul', 'div']: if op in ['mul', 'div']:
dX = _scatter_nd( dX = _scatter_nd(
argX, argX,
_muldiv(op, _gather_nd(argY, tf.broadcast_to(Y, (Y.shape[0], *dZ.shape[1:])))) * dZ, _muldiv(op, _gather_nd(argY, _expand(Y, dZ.shape[1:]))) * dZ,
X.shape[0]) X.shape[0])
elif op in ['add', 'sub', 'copy_lhs']: elif op in ['add', 'sub', 'copy_lhs']:
dX = _scatter_nd(argX, dZ, X.shape[0]) dX = _scatter_nd(argX, dZ, X.shape[0])
...@@ -131,7 +137,7 @@ def gspmm_real(gidx, op, reduce_op, X, Y): ...@@ -131,7 +137,7 @@ def gspmm_real(gidx, op, reduce_op, X, Y):
if op in ['mul', 'div']: if op in ['mul', 'div']:
dY = _scatter_nd( dY = _scatter_nd(
argY, argY,
_gather_nd(argX, tf.broadcast_to(X, (X.shape[0], *dZ.shape[1:]))) * dZ, _gather_nd(argX, _expand(X, dZ.shape[1:])) * dZ,
Y.shape[0]) Y.shape[0])
if op == 'div': dY = -dY / (Y ** 2) if op == 'div': dY = -dY / (Y ** 2)
elif op in ['add', 'sub', 'copy_rhs']: elif op in ['add', 'sub', 'copy_rhs']:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment