Unverified Commit a28bfa9f authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[hotfix] Activate kernel unittest for tensorflow (#1895)



* upd

* upd

* upd

* upd

* upd

* trigger

* simplify unittest
Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
parent 18c960a1
...@@ -40,3 +40,25 @@ tips to check when you get an OOM error. ...@@ -40,3 +40,25 @@ tips to check when you get an OOM error.
* If your scenario does not require autograd, you can use ``inplace=True`` flag * If your scenario does not require autograd, you can use ``inplace=True`` flag
in the message passing APIs. This will update features inplacely that might in the message passing APIs. This will update features inplacely that might
save memory. save memory.
Reproducibility
---------------
Like PyTorch, we also do not guarantee completely reproducible results across multiple releases,
individual commits or different platforms.
However, from DGL v0.5 on, we guarantee determinism on both CPU and GPU for most of the operators defined in
``dgl.ops`` (and thus for the built-in message-passing functions). That is, you will get exactly the same
output/gradients across multiple runs once you fix the random seeds of Python, NumPy, and the backend framework.
You should get the same training loss/accuracy if your program only uses deterministic operators of the backend
framework (for PyTorch, see https://pytorch.org/docs/stable/notes/randomness.html) and deterministic DGL
message-passing operators/functions.
For message passing, determinism is not guaranteed only in the following cases:
1. The backward phase of the Min/Max reduce functions (we depend on the ``scatter_add_`` operator of the backend
frameworks, which is not guaranteed to be deterministic).
2. Message passing on ``DGLGraph``'s restricted to the ``COO`` format (this only happens when the user specifies
``formats='coo'`` when creating the graph; normal users should not specify the ``formats`` argument, which is
designed only for expert users who need to handle extremely large graphs).
Note that although the operators above are not deterministic, the difference across multiple runs is quite small.
...@@ -5,7 +5,9 @@ from ...sparse import _gspmm, _gsddmm ...@@ -5,7 +5,9 @@ from ...sparse import _gspmm, _gsddmm
from ...base import dgl_warning from ...base import dgl_warning
from .tensor import asnumpy, copy_to, zerocopy_from_numpy, context, to_backend_ctx from .tensor import asnumpy, copy_to, zerocopy_from_numpy, context, to_backend_ctx
def _scatter_nd(index, src, n_rows): def _scatter_nd(index, src, n_rows):
"""Similar to PyTorch's scatter nd on first dimension."""
assert index.shape == src.shape assert index.shape == src.shape
dgl_warning("MXNet do not support scatter_add, fallback to numpy.") dgl_warning("MXNet do not support scatter_add, fallback to numpy.")
ctx = context(src) ctx = context(src)
...@@ -30,7 +32,9 @@ def _scatter_nd(index, src, n_rows): ...@@ -30,7 +32,9 @@ def _scatter_nd(index, src, n_rows):
rst = copy_to(zerocopy_from_numpy(rst), ctx) rst = copy_to(zerocopy_from_numpy(rst), ctx)
return rst return rst
def _gather_nd(index, src): def _gather_nd(index, src):
"""Similar to PyTorch's gather nd on first dimension."""
ctx = context(src) ctx = context(src)
shp = index.shape shp = index.shape
ndim = src.ndim ndim = src.ndim
...@@ -48,6 +52,7 @@ def _gather_nd(index, src): ...@@ -48,6 +52,7 @@ def _gather_nd(index, src):
rst = nd.take(src, new_idx).reshape(shp) rst = nd.take(src, new_idx).reshape(shp)
return rst return rst
def _reduce_grad(grad, shape): def _reduce_grad(grad, shape):
"""Reduce gradient on the broadcast dimension """Reduce gradient on the broadcast dimension
If there is broadcast in forward pass, gradients need to be reduced on If there is broadcast in forward pass, gradients need to be reduced on
...@@ -80,6 +85,7 @@ def _reduce_grad(grad, shape): ...@@ -80,6 +85,7 @@ def _reduce_grad(grad, shape):
grad = grad.sum(axis=tuple(reduce_idx), keepdims=True) grad = grad.sum(axis=tuple(reduce_idx), keepdims=True)
return grad.reshape(shape) return grad.reshape(shape)
def _need_reduce_last_dim(ufeat, efeat): def _need_reduce_last_dim(ufeat, efeat):
"""Indicates whether to reduce the last dimension on edges """Indicates whether to reduce the last dimension on edges
in the backward pass of spmm, in the backward pass of spmm,
...@@ -88,12 +94,15 @@ def _need_reduce_last_dim(ufeat, efeat): ...@@ -88,12 +94,15 @@ def _need_reduce_last_dim(ufeat, efeat):
eshp = efeat.shape eshp = efeat.shape
return ushp[1:-1] == eshp[1:-1] and eshp[-1] == 1 and ushp[-1] > 1 return ushp[1:-1] == eshp[1:-1] and eshp[-1] == 1 and ushp[-1] > 1
def _muldiv(op, x): def _muldiv(op, x):
return 1. / x if op == 'div' else x return 1. / x if op == 'div' else x
def _addsub(op, x): def _addsub(op, x):
return -x if op == 'sub' else x return -x if op == 'sub' else x
class GSpMM(mx.autograd.Function): class GSpMM(mx.autograd.Function):
def __init__(self, gidx, op, reduce_op): def __init__(self, gidx, op, reduce_op):
super(GSpMM, self).__init__() super(GSpMM, self).__init__()
...@@ -136,7 +145,8 @@ class GSpMM(mx.autograd.Function): ...@@ -136,7 +145,8 @@ class GSpMM(mx.autograd.Function):
dY = _gsddmm(gidx, 'dot', X, dZ) dY = _gsddmm(gidx, 'dot', X, dZ)
elif op in ['mul', 'div']: elif op in ['mul', 'div']:
dY = _gsddmm(gidx, 'mul', X, dZ) dY = _gsddmm(gidx, 'mul', X, dZ)
if op == 'div': dY = -dY / (Y ** 2) if op == 'div':
dY = -dY / (Y ** 2)
elif op in ['add', 'sub', 'copy_rhs']: elif op in ['add', 'sub', 'copy_rhs']:
dY = _gsddmm(gidx, 'copy_rhs', X, _addsub(op, dZ)) dY = _gsddmm(gidx, 'copy_rhs', X, _addsub(op, dZ))
else: else:
...@@ -145,7 +155,8 @@ class GSpMM(mx.autograd.Function): ...@@ -145,7 +155,8 @@ class GSpMM(mx.autograd.Function):
argY, argY,
_gather_nd(argX, X.broadcast_to((X.shape[0], *dZ.shape[1:]))) * dZ, _gather_nd(argX, X.broadcast_to((X.shape[0], *dZ.shape[1:]))) * dZ,
Y.shape[0]) Y.shape[0])
if op == 'div': dY = -dY / (Y ** 2) if op == 'div':
dY = -dY / (Y ** 2)
elif op in ['add', 'sub', 'copy_rhs']: elif op in ['add', 'sub', 'copy_rhs']:
dY = _scatter_nd(argY, _addsub(op, dZ), Y.shape[0]) dY = _scatter_nd(argY, _addsub(op, dZ), Y.shape[0])
dY = _reduce_grad(dY, Y.shape) dY = _reduce_grad(dY, Y.shape)
...@@ -154,6 +165,7 @@ class GSpMM(mx.autograd.Function): ...@@ -154,6 +165,7 @@ class GSpMM(mx.autograd.Function):
self.saved_tensors = None self.saved_tensors = None
return dX, dY return dX, dY
def gspmm(gidx, op, reduce_op, lhs_data, rhs_data): def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
func = GSpMM(gidx, op, reduce_op) func = GSpMM(gidx, op, reduce_op)
ctx = to_backend_ctx(gidx.ctx) ctx = to_backend_ctx(gidx.ctx)
...@@ -163,6 +175,7 @@ def gspmm(gidx, op, reduce_op, lhs_data, rhs_data): ...@@ -163,6 +175,7 @@ def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
rhs_data = nd.zeros((1,), ctx=ctx) rhs_data = nd.zeros((1,), ctx=ctx)
return func(lhs_data, rhs_data) return func(lhs_data, rhs_data)
class GSDDMM(mx.autograd.Function): class GSDDMM(mx.autograd.Function):
def __init__(self, gidx, op, lhs_target, rhs_target): def __init__(self, gidx, op, lhs_target, rhs_target):
super(GSDDMM, self).__init__() super(GSDDMM, self).__init__()
...@@ -213,19 +226,22 @@ class GSDDMM(mx.autograd.Function): ...@@ -213,19 +226,22 @@ class GSDDMM(mx.autograd.Function):
dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * X)[0] dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * X)[0]
else: # rhs_target = !lhs_target else: # rhs_target = !lhs_target
dY = _gspmm(_gidx, 'mul', 'sum', X, dZ)[0] dY = _gspmm(_gidx, 'mul', 'sum', X, dZ)[0]
if op == 'div': dY = -dY / (Y ** 2) if op == 'div':
dY = -dY / (Y ** 2)
else: else:
if op in ['add', 'sub', 'copy_rhs']: if op in ['add', 'sub', 'copy_rhs']:
dY = _addsub(op, dZ) dY = _addsub(op, dZ)
else: # mul, div, dot else: # mul, div, dot
dY = _gsddmm(gidx, 'mul', dZ, X, 'e', lhs_target) dY = _gsddmm(gidx, 'mul', dZ, X, 'e', lhs_target)
if op == 'div': dY = -dY / (Y ** 2) if op == 'div':
dY = -dY / (Y ** 2)
dY = _reduce_grad(dY, Y.shape) dY = _reduce_grad(dY, Y.shape)
else: else:
dY = nd.zeros_like(Y) dY = nd.zeros_like(Y)
self.saved_tensors = None self.saved_tensors = None
return dX, dY return dX, dY
def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'): def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
func = GSDDMM(gidx, op, lhs_target, rhs_target) func = GSDDMM(gidx, op, lhs_target, rhs_target)
ctx = to_backend_ctx(gidx.ctx) ctx = to_backend_ctx(gidx.ctx)
......
...@@ -3,6 +3,7 @@ from ...sparse import _gspmm, _gsddmm ...@@ -3,6 +3,7 @@ from ...sparse import _gspmm, _gsddmm
__all__ = ['gspmm', 'gsddmm'] __all__ = ['gspmm', 'gsddmm']
def _reduce_grad(grad, shape): def _reduce_grad(grad, shape):
"""Reduce gradient on the broadcast dimension """Reduce gradient on the broadcast dimension
If there is broadcast in forward pass, gradients need to be reduced on If there is broadcast in forward pass, gradients need to be reduced on
...@@ -34,6 +35,7 @@ def _reduce_grad(grad, shape): ...@@ -34,6 +35,7 @@ def _reduce_grad(grad, shape):
grad = grad.sum(dim=tuple(reduce_idx), keepdim=True) grad = grad.sum(dim=tuple(reduce_idx), keepdim=True)
return grad.view(-1, *shape[1:]) return grad.view(-1, *shape[1:])
def _need_reduce_last_dim(ufeat, efeat): def _need_reduce_last_dim(ufeat, efeat):
"""Indicates whether to reduce the last dimension on edges """Indicates whether to reduce the last dimension on edges
in the backward pass of spmm, in the backward pass of spmm,
...@@ -42,12 +44,15 @@ def _need_reduce_last_dim(ufeat, efeat): ...@@ -42,12 +44,15 @@ def _need_reduce_last_dim(ufeat, efeat):
eshp = efeat.shape eshp = efeat.shape
return ushp[1:-1] == eshp[1:-1] and eshp[-1] == 1 and ushp[-1] > 1 return ushp[1:-1] == eshp[1:-1] and eshp[-1] == 1 and ushp[-1] > 1
def _muldiv(op, x): def _muldiv(op, x):
return 1. / x if op == 'div' else x return 1. / x if op == 'div' else x
def _addsub(op, x): def _addsub(op, x):
return -x if op == 'sub' else x return -x if op == 'sub' else x
class GSpMM(th.autograd.Function): class GSpMM(th.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, gidx, op, reduce_op, X, Y): def forward(ctx, gidx, op, reduce_op, X, Y):
...@@ -69,15 +74,17 @@ class GSpMM(th.autograd.Function): ...@@ -69,15 +74,17 @@ class GSpMM(th.autograd.Function):
dX = gspmm(g_rev, 'copy_lhs', 'sum', dZ, Y) dX = gspmm(g_rev, 'copy_lhs', 'sum', dZ, Y)
elif op == 'copy_lhs': elif op == 'copy_lhs':
dX = gspmm(g_rev, 'copy_lhs', 'sum', dZ, None) dX = gspmm(g_rev, 'copy_lhs', 'sum', dZ, None)
else: else: # max/min
dX = th.zeros((X.shape[0],) + dZ.shape[1:], dtype=X.dtype, device=X.device) dX = th.zeros((X.shape[0],) + dZ.shape[1:],
dtype=X.dtype, device=X.device)
if op in ['mul', 'div']: if op in ['mul', 'div']:
dX.scatter_add_(0, argX.long(), grad = _muldiv(op, Y.expand(-1, *dZ.shape[1:]).gather(
_muldiv(op, Y.expand(-1, *dZ.shape[1:]).gather(0, argY.long())) * dZ) 0, argY.long())) * dZ
dX.scatter_add_(0, argX.long(), grad)
elif op in ['add', 'sub', 'copy_lhs']: elif op in ['add', 'sub', 'copy_lhs']:
dX.scatter_add_(0, argX.long(), dZ) dX.scatter_add_(0, argX.long(), dZ)
dX = _reduce_grad(dX, X.shape) dX = _reduce_grad(dX, X.shape)
else: else: # X has no gradient
dX = None dX = None
if op != 'copy_lhs' and ctx.needs_input_grad[4]: if op != 'copy_lhs' and ctx.needs_input_grad[4]:
if reduce_op == 'sum': if reduce_op == 'sum':
...@@ -85,22 +92,27 @@ class GSpMM(th.autograd.Function): ...@@ -85,22 +92,27 @@ class GSpMM(th.autograd.Function):
dY = gsddmm(gidx, 'dot', X, dZ) dY = gsddmm(gidx, 'dot', X, dZ)
elif op in ['mul', 'div']: elif op in ['mul', 'div']:
dY = gsddmm(gidx, 'mul', X, dZ) dY = gsddmm(gidx, 'mul', X, dZ)
if op == 'div': dY = -dY / (Y ** 2) if op == 'div':
dY = -dY / (Y ** 2)
elif op in ['add', 'sub', 'copy_rhs']: elif op in ['add', 'sub', 'copy_rhs']:
dY = gsddmm(gidx, 'copy_rhs', X, _addsub(op, dZ)) dY = gsddmm(gidx, 'copy_rhs', X, _addsub(op, dZ))
else: else: # max/min
dY = th.zeros((Y.shape[0],) + dZ.shape[1:], dtype=Y.dtype, device=Y.device) dY = th.zeros((Y.shape[0],) + dZ.shape[1:],
dtype=Y.dtype, device=Y.device)
if op in ['mul', 'div']: if op in ['mul', 'div']:
dY.scatter_add_(0, argY.long(), grad = X.expand(-1, *dZ.shape[1:]).gather(
X.expand(-1, *dZ.shape[1:]).gather(0, argX.long()) * dZ) 0, argX.long()) * dZ
if op == 'div': dY = -dY / (Y ** 2) dY.scatter_add_(0, argY.long(), grad)
if op == 'div':
dY = -dY / (Y ** 2)
elif op in ['add', 'sub', 'copy_rhs']: elif op in ['add', 'sub', 'copy_rhs']:
dY.scatter_add_(0, argY.long(), _addsub(op, dZ)) dY.scatter_add_(0, argY.long(), _addsub(op, dZ))
dY = _reduce_grad(dY, Y.shape) dY = _reduce_grad(dY, Y.shape)
else: else: # Y has no gradient
dY = None dY = None
return None, None, None, dX, dY return None, None, None, dX, dY
class GSDDMM(th.autograd.Function): class GSDDMM(th.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, gidx, op, X, Y, lhs_target, rhs_target): def forward(ctx, gidx, op, X, Y, lhs_target, rhs_target):
...@@ -159,9 +171,10 @@ class GSDDMM(th.autograd.Function): ...@@ -159,9 +171,10 @@ class GSDDMM(th.autograd.Function):
dY = None dY = None
return None, None, dX, dY, None, None return None, None, dX, dY, None, None
def gspmm(gidx, op, reduce_op, lhs_data, rhs_data): def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
return GSpMM.apply(gidx, op, reduce_op, lhs_data, rhs_data) return GSpMM.apply(gidx, op, reduce_op, lhs_data, rhs_data)
def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'): def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
return GSDDMM.apply(gidx, op, lhs_data, rhs_data, lhs_target, rhs_target) return GSDDMM.apply(gidx, op, lhs_data, rhs_data, lhs_target, rhs_target)
...@@ -3,6 +3,7 @@ import numpy as np ...@@ -3,6 +3,7 @@ import numpy as np
from .tensor import tensor, copy_to, context from .tensor import tensor, copy_to, context
from ...sparse import _gspmm, _gsddmm from ...sparse import _gspmm, _gsddmm
def _scatter_nd(index, src, n_rows): def _scatter_nd(index, src, n_rows):
assert index.shape == src.shape assert index.shape == src.shape
shp = index.shape shp = index.shape
...@@ -22,6 +23,7 @@ def _scatter_nd(index, src, n_rows): ...@@ -22,6 +23,7 @@ def _scatter_nd(index, src, n_rows):
rst = tf.reshape(tf.scatter_nd(new_idx, src, (stride * n_rows,)), (n_rows, *shp[1:])) rst = tf.reshape(tf.scatter_nd(new_idx, src, (stride * n_rows,)), (n_rows, *shp[1:]))
return rst return rst
def _gather_nd(index, src): def _gather_nd(index, src):
shp = index.shape shp = index.shape
ctx = context(src) ctx = context(src)
...@@ -41,6 +43,7 @@ def _gather_nd(index, src): ...@@ -41,6 +43,7 @@ def _gather_nd(index, src):
rst = tf.reshape(tf.gather(src, new_idx), shp) rst = tf.reshape(tf.gather(src, new_idx), shp)
return rst return rst
def _reduce_grad(grad, shape): def _reduce_grad(grad, shape):
"""Reduce gradient on the broadcast dimension """Reduce gradient on the broadcast dimension
If there is broadcast in forward pass, gradients need to be reduced on If there is broadcast in forward pass, gradients need to be reduced on
...@@ -71,6 +74,7 @@ def _reduce_grad(grad, shape): ...@@ -71,6 +74,7 @@ def _reduce_grad(grad, shape):
grad = tf.reduce_sum(grad, axis=reduce_idx_tensor, keepdims=True) grad = tf.reduce_sum(grad, axis=reduce_idx_tensor, keepdims=True)
return tf.reshape(grad, shape) return tf.reshape(grad, shape)
def _need_reduce_last_dim(ufeat, efeat): def _need_reduce_last_dim(ufeat, efeat):
"""Indicates whether to reduce the last dimension on edges """Indicates whether to reduce the last dimension on edges
in the backward pass of spmm, in the backward pass of spmm,
...@@ -79,12 +83,15 @@ def _need_reduce_last_dim(ufeat, efeat): ...@@ -79,12 +83,15 @@ def _need_reduce_last_dim(ufeat, efeat):
eshp = efeat.shape eshp = efeat.shape
return ushp[1:-1] == eshp[1:-1] and eshp[-1] == 1 and ushp[-1] > 1 return ushp[1:-1] == eshp[1:-1] and eshp[-1] == 1 and ushp[-1] > 1
def _muldiv(op, x): def _muldiv(op, x):
return 1. / x if op == 'div' else x return 1. / x if op == 'div' else x
def _addsub(op, x): def _addsub(op, x):
return -x if op == 'sub' else x return -x if op == 'sub' else x
def gspmm_real(gidx, op, reduce_op, X, Y): def gspmm_real(gidx, op, reduce_op, X, Y):
out, (argX, argY) = _gspmm(gidx, op, reduce_op, X, Y) out, (argX, argY) = _gspmm(gidx, op, reduce_op, X, Y)
...@@ -135,6 +142,7 @@ def gspmm_real(gidx, op, reduce_op, X, Y): ...@@ -135,6 +142,7 @@ def gspmm_real(gidx, op, reduce_op, X, Y):
return dX, dY return dX, dY
return out, grad return out, grad
def gspmm(gidx, op, reduce_op, X, Y): def gspmm(gidx, op, reduce_op, X, Y):
@tf.custom_gradient @tf.custom_gradient
def _lambda(X, Y): def _lambda(X, Y):
...@@ -145,6 +153,7 @@ def gspmm(gidx, op, reduce_op, X, Y): ...@@ -145,6 +153,7 @@ def gspmm(gidx, op, reduce_op, X, Y):
Y = tf.zeros(()) Y = tf.zeros(())
return _lambda(X, Y) return _lambda(X, Y)
def gsddmm_real(gidx, op, X, Y, lhs_target, rhs_target): def gsddmm_real(gidx, op, X, Y, lhs_target, rhs_target):
out = _gsddmm(gidx, op, X, Y, lhs_target, rhs_target) out = _gsddmm(gidx, op, X, Y, lhs_target, rhs_target)
...@@ -181,19 +190,22 @@ def gsddmm_real(gidx, op, X, Y, lhs_target, rhs_target): ...@@ -181,19 +190,22 @@ def gsddmm_real(gidx, op, X, Y, lhs_target, rhs_target):
dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * X)[0] dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * X)[0]
else: # rhs_target = !lhs_target else: # rhs_target = !lhs_target
dY = _gspmm(_gidx, 'mul', 'sum', X, dZ)[0] dY = _gspmm(_gidx, 'mul', 'sum', X, dZ)[0]
if op == 'div': dY = -dY / (Y ** 2) if op == 'div':
dY = -dY / (Y ** 2)
else: else:
if op in ['add', 'sub', 'copy_rhs']: if op in ['add', 'sub', 'copy_rhs']:
dY = _addsub(op, dZ) dY = _addsub(op, dZ)
else: # mul, div, dot else: # mul, div, dot
dY = _gsddmm(gidx, 'mul', dZ, X, 'e', lhs_target) dY = _gsddmm(gidx, 'mul', dZ, X, 'e', lhs_target)
if op == 'div': dY = -dY / (Y ** 2) if op == 'div':
dY = -dY / (Y ** 2)
dY = _reduce_grad(dY, Y.shape) dY = _reduce_grad(dY, Y.shape)
else: else:
dY = tf.zeros_like(Y) dY = tf.zeros_like(Y)
return dX, dY return dX, dY
return out, grad return out, grad
def gsddmm(gidx, op, X, Y, lhs_target='u', rhs_target='v'): def gsddmm(gidx, op, X, Y, lhs_target='u', rhs_target='v'):
@tf.custom_gradient @tf.custom_gradient
def _lambda(X, Y): def _lambda(X, Y):
......
...@@ -43,6 +43,7 @@ else: ...@@ -43,6 +43,7 @@ else:
def zerocopy_from_dlpack(input): def zerocopy_from_dlpack(input):
return tfdlpack.from_dlpack(input) return tfdlpack.from_dlpack(input)
def data_type_dict(): def data_type_dict():
return {'float16': tf.float16, return {'float16': tf.float16,
'float32': tf.float32, 'float32': tf.float32,
...@@ -54,9 +55,11 @@ def data_type_dict(): ...@@ -54,9 +55,11 @@ def data_type_dict():
'int64': tf.int64, 'int64': tf.int64,
'bool' : tf.bool} 'bool' : tf.bool}
def cpu(): def cpu():
return "/cpu:0" return "/cpu:0"
def tensor(data, dtype=None): def tensor(data, dtype=None):
if isinstance(data, tf.Tensor): if isinstance(data, tf.Tensor):
if dtype is None or data.dtype == dtype: if dtype is None or data.dtype == dtype:
...@@ -68,13 +71,16 @@ def tensor(data, dtype=None): ...@@ -68,13 +71,16 @@ def tensor(data, dtype=None):
data = [data] data = [data]
return tf.convert_to_tensor(data, dtype=dtype) return tf.convert_to_tensor(data, dtype=dtype)
def initialize_context(): def initialize_context():
tf.zeros(1) tf.zeros(1)
def as_scalar(data): def as_scalar(data):
data = data.numpy() data = data.numpy()
return data if np.isscalar(data) else data.item() return data if np.isscalar(data) else data.item()
def get_preferred_sparse_format(): def get_preferred_sparse_format():
"""Get the preferred sparse matrix format supported by the backend. """Get the preferred sparse matrix format supported by the backend.
...@@ -127,6 +133,7 @@ def device_type(ctx): ...@@ -127,6 +133,7 @@ def device_type(ctx):
def device_id(ctx): def device_id(ctx):
return tf.DeviceSpec.from_string(ctx).device_index return tf.DeviceSpec.from_string(ctx).device_index
def to_backend_ctx(dglctx): def to_backend_ctx(dglctx):
dev_type = dglctx.device_type dev_type = dglctx.device_type
if dev_type == 1: if dev_type == 1:
...@@ -136,6 +143,7 @@ def to_backend_ctx(dglctx): ...@@ -136,6 +143,7 @@ def to_backend_ctx(dglctx):
else: else:
raise ValueError('Unsupported DGL device context:', dglctx) raise ValueError('Unsupported DGL device context:', dglctx)
def astype(input, ty): def astype(input, ty):
return tf.cast(input, dtype=ty) return tf.cast(input, dtype=ty)
...@@ -226,9 +234,11 @@ def argtopk(input, k, dim, descending=True): ...@@ -226,9 +234,11 @@ def argtopk(input, k, dim, descending=True):
def exp(input): def exp(input):
return tf.exp(input) return tf.exp(input)
def sqrt(input): def sqrt(input):
return tf.sqrt(input) return tf.sqrt(input)
def softmax(input, dim=-1): def softmax(input, dim=-1):
return tf.math.softmax(input, axis=dim) return tf.math.softmax(input, axis=dim)
...@@ -275,6 +285,7 @@ def scatter_row(data, row_index, value): ...@@ -275,6 +285,7 @@ def scatter_row(data, row_index, value):
row_index = tf.expand_dims(row_index, 1) row_index = tf.expand_dims(row_index, 1)
return tf.tensor_scatter_nd_update(data, row_index, value) return tf.tensor_scatter_nd_update(data, row_index, value)
def index_add_inplace(data, row_idx, value): def index_add_inplace(data, row_idx, value):
raise NotImplementedError("Tensorflow doesn't support inplace index_add") raise NotImplementedError("Tensorflow doesn't support inplace index_add")
...@@ -436,9 +447,11 @@ def zerocopy_to_dgl_ndarray(data): ...@@ -436,9 +447,11 @@ def zerocopy_to_dgl_ndarray(data):
else: else:
return nd.from_dlpack(zerocopy_to_dlpack(data)) return nd.from_dlpack(zerocopy_to_dlpack(data))
def zerocopy_to_dgl_ndarray_for_write(input): def zerocopy_to_dgl_ndarray_for_write(input):
return zerocopy_to_dgl_ndarray(input) return zerocopy_to_dgl_ndarray(input)
def zerocopy_from_dgl_ndarray(input): def zerocopy_from_dgl_ndarray(input):
return zerocopy_from_dlpack(input.to_dlpack()) return zerocopy_from_dlpack(input.to_dlpack())
...@@ -575,6 +588,7 @@ def copy_reduce_real(reducer, graph, target, in_data, out_size, in_map, ...@@ -575,6 +588,7 @@ def copy_reduce_real(reducer, graph, target, in_data, out_size, in_map,
return grad_in return grad_in
return out_data, grad return out_data, grad
def _reduce_grad(grad, shape): def _reduce_grad(grad, shape):
"""Reduce gradient on the broadcast dimension """Reduce gradient on the broadcast dimension
...@@ -613,6 +627,7 @@ def sync(): ...@@ -613,6 +627,7 @@ def sync():
context = context().context() context = context().context()
context.async_wait() context.async_wait()
class GradContext: class GradContext:
def __init__(self): def __init__(self):
self.tensor_for_grad = [] self.tensor_for_grad = []
...@@ -693,9 +708,11 @@ def backward(x, head_gradient=None): ...@@ -693,9 +708,11 @@ def backward(x, head_gradient=None):
def grad(x): def grad(x):
return cgrad.grad(x) return cgrad.grad(x)
def is_no_grad(x): def is_no_grad(x):
return cgrad.is_no_grad(x) return cgrad.is_no_grad(x)
def is_recording(): def is_recording():
raise NotImplementedError("Tensorflow doesn't support is_recording") raise NotImplementedError("Tensorflow doesn't support is_recording")
......
...@@ -6,6 +6,7 @@ from ..backend import gsddmm as gsddmm_internal ...@@ -6,6 +6,7 @@ from ..backend import gsddmm as gsddmm_internal
__all__ = ['gsddmm', 'copy_u', 'copy_v'] __all__ = ['gsddmm', 'copy_u', 'copy_v']
def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'): def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface. r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface.
It computes edge features by :attr:`op` lhs features and rhs features. It computes edge features by :attr:`op` lhs features and rhs features.
...@@ -43,6 +44,7 @@ def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'): ...@@ -43,6 +44,7 @@ def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
return gsddmm_internal( return gsddmm_internal(
g._graph, op, lhs_data, rhs_data, lhs_target, rhs_target) g._graph, op, lhs_data, rhs_data, lhs_target, rhs_target)
def _gen_sddmm_func(lhs_target, rhs_target, binary_op): def _gen_sddmm_func(lhs_target, rhs_target, binary_op):
name = "{}_{}_{}".format(lhs_target, binary_op, rhs_target) name = "{}_{}_{}".format(lhs_target, binary_op, rhs_target)
target_dict = { target_dict = {
...@@ -87,6 +89,7 @@ def _gen_sddmm_func(lhs_target, rhs_target, binary_op): ...@@ -87,6 +89,7 @@ def _gen_sddmm_func(lhs_target, rhs_target, binary_op):
func.__doc__ = docstring func.__doc__ = docstring
return func return func
def _register_sddmm_func(): def _register_sddmm_func():
"""Register sddmm functions""" """Register sddmm functions"""
target = ["u", "v", "e"] target = ["u", "v", "e"]
...@@ -97,6 +100,7 @@ def _register_sddmm_func(): ...@@ -97,6 +100,7 @@ def _register_sddmm_func():
setattr(sys.modules[__name__], func.__name__, func) setattr(sys.modules[__name__], func.__name__, func)
__all__.append(func.__name__) __all__.append(func.__name__)
def copy_u(g, x): def copy_u(g, x):
r"""Generalized SDDMM function that copies source node features to edges. r"""Generalized SDDMM function that copies source node features to edges.
...@@ -118,6 +122,7 @@ def copy_u(g, x): ...@@ -118,6 +122,7 @@ def copy_u(g, x):
""" """
return gsddmm(g, 'copy_lhs', x, None) return gsddmm(g, 'copy_lhs', x, None)
def copy_v(g, x): def copy_v(g, x):
r"""Generalized SDDMM function that copies destination node features to edges. r"""Generalized SDDMM function that copies destination node features to edges.
...@@ -139,4 +144,5 @@ def copy_v(g, x): ...@@ -139,4 +144,5 @@ def copy_v(g, x):
""" """
return gsddmm(g, 'copy_rhs', None, x) return gsddmm(g, 'copy_rhs', None, x)
_register_sddmm_func() _register_sddmm_func()
...@@ -5,6 +5,7 @@ from .. import backend as F ...@@ -5,6 +5,7 @@ from .. import backend as F
from .. import convert from .. import convert
from .. import function as fn from .. import function as fn
def segment_reduce(seglen, value, reducer='sum'): def segment_reduce(seglen, value, reducer='sum'):
"""Segment reduction operator. """Segment reduction operator.
...@@ -54,6 +55,7 @@ def segment_reduce(seglen, value, reducer='sum'): ...@@ -54,6 +55,7 @@ def segment_reduce(seglen, value, reducer='sum'):
g.update_all(fn.copy_u('h', 'm'), getattr(fn, reducer)('m', 'h')) g.update_all(fn.copy_u('h', 'm'), getattr(fn, reducer)('m', 'h'))
return g.dstdata['h'] return g.dstdata['h']
def segment_softmax(seglen, value): def segment_softmax(seglen, value):
"""Perform softmax on each segment. """Perform softmax on each segment.
......
...@@ -43,6 +43,7 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data): ...@@ -43,6 +43,7 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data):
""" """
return gspmm_internal(g._graph, op, reduce_op, lhs_data, rhs_data) return gspmm_internal(g._graph, op, reduce_op, lhs_data, rhs_data)
def _gen_spmm_func(binary_op, reduce_op): def _gen_spmm_func(binary_op, reduce_op):
name = "u_{}_e_{}".format(binary_op, reduce_op) name = "u_{}_e_{}".format(binary_op, reduce_op)
docstring = """Generalized SpMM function. docstring = """Generalized SpMM function.
...@@ -82,6 +83,7 @@ def _gen_spmm_func(binary_op, reduce_op): ...@@ -82,6 +83,7 @@ def _gen_spmm_func(binary_op, reduce_op):
func.__doc__ = docstring func.__doc__ = docstring
return func return func
def _gen_copy_reduce_func(binary_op, reduce_op): def _gen_copy_reduce_func(binary_op, reduce_op):
name = "{}_{}".format(binary_op, reduce_op) name = "{}_{}".format(binary_op, reduce_op)
...@@ -126,6 +128,7 @@ def _gen_copy_reduce_func(binary_op, reduce_op): ...@@ -126,6 +128,7 @@ def _gen_copy_reduce_func(binary_op, reduce_op):
func.__doc__ = docstring(binary_op) func.__doc__ = docstring(binary_op)
return func return func
def _register_spmm_func(): def _register_spmm_func():
"""Register spmm functions""" """Register spmm functions"""
for binary_op in ["add", "sub", "mul", "div", "copy_u", "copy_e"]: for binary_op in ["add", "sub", "mul", "div", "copy_u", "copy_e"]:
...@@ -137,4 +140,5 @@ def _register_spmm_func(): ...@@ -137,4 +140,5 @@ def _register_spmm_func():
setattr(sys.modules[__name__], func.__name__, func) setattr(sys.modules[__name__], func.__name__, func)
__all__.append(func.__name__) __all__.append(func.__name__)
_register_spmm_func() _register_spmm_func()
...@@ -7,6 +7,7 @@ from ._ffi.function import _init_api ...@@ -7,6 +7,7 @@ from ._ffi.function import _init_api
from .base import DGLError from .base import DGLError
from . import backend as F from . import backend as F
def infer_broadcast_shape(op, shp1, shp2): def infer_broadcast_shape(op, shp1, shp2):
r"""Check the shape validity, and infer the output shape given input shape and operator. r"""Check the shape validity, and infer the output shape given input shape and operator.
Note the both :attr:`shp1`, :attr:`shp2` and the returned shape are feature Note the both :attr:`shp1`, :attr:`shp2` and the returned shape are feature
...@@ -53,14 +54,17 @@ def infer_broadcast_shape(op, shp1, shp2): ...@@ -53,14 +54,17 @@ def infer_broadcast_shape(op, shp1, shp2):
rst = tuple(max(d1, d2) for d1, d2 in zip(pad_shp1, pad_shp2)) rst = tuple(max(d1, d2) for d1, d2 in zip(pad_shp1, pad_shp2))
return rst[:-1] + (1,) if op == "dot" else rst return rst[:-1] + (1,) if op == "dot" else rst
def to_dgl_nd(x): def to_dgl_nd(x):
"""Convert framework-specific tensor/None to dgl ndarray.""" """Convert framework-specific tensor/None to dgl ndarray."""
return nd.NULL['int64'] if x is None else F.zerocopy_to_dgl_ndarray(x) return nd.NULL['int64'] if x is None else F.zerocopy_to_dgl_ndarray(x)
def to_dgl_nd_for_write(x): def to_dgl_nd_for_write(x):
"""Convert framework-specific tensor/None to dgl ndarray for write.""" """Convert framework-specific tensor/None to dgl ndarray for write."""
return nd.NULL['int64'] if x is None else F.zerocopy_to_dgl_ndarray_for_write(x) return nd.NULL['int64'] if x is None else F.zerocopy_to_dgl_ndarray_for_write(x)
target_mapping = { target_mapping = {
'u': 0, 'u': 0,
'e': 1, 'e': 1,
...@@ -70,6 +74,7 @@ target_mapping = { ...@@ -70,6 +74,7 @@ target_mapping = {
'dst': 2 'dst': 2
} }
def _gspmm(gidx, op, reduce_op, u, e): def _gspmm(gidx, op, reduce_op, u, e):
r""" Generalized Sparse Matrix Multiplication interface. It takes the result of r""" Generalized Sparse Matrix Multiplication interface. It takes the result of
:attr:`op` on source node feature and edge feature, leads to a message on edge. :attr:`op` on source node feature and edge feature, leads to a message on edge.
...@@ -141,17 +146,30 @@ def _gspmm(gidx, op, reduce_op, u, e): ...@@ -141,17 +146,30 @@ def _gspmm(gidx, op, reduce_op, u, e):
arg_u = F.zeros(v_shp, idtype, ctx) arg_u = F.zeros(v_shp, idtype, ctx)
if use_e: if use_e:
arg_e = F.zeros(v_shp, idtype, ctx) arg_e = F.zeros(v_shp, idtype, ctx)
arg_u_nd = to_dgl_nd_for_write(arg_u)
arg_e_nd = to_dgl_nd_for_write(arg_e)
if gidx.number_of_edges(0) > 0: if gidx.number_of_edges(0) > 0:
_CAPI_DGLKernelSpMM(gidx, op, reduce_op, _CAPI_DGLKernelSpMM(gidx, op, reduce_op,
to_dgl_nd(u if use_u else None), to_dgl_nd(u if use_u else None),
to_dgl_nd(e if use_e else None), to_dgl_nd(e if use_e else None),
to_dgl_nd_for_write(v), to_dgl_nd_for_write(v),
to_dgl_nd_for_write(arg_u), arg_u_nd,
to_dgl_nd_for_write(arg_e)) arg_e_nd)
# NOTE(zihao): in principle the following step is unnecessary, because arg_*_nd
# refers to the data that stores arg_*, so after _CAPI_DGLKernelSpMM returns,
# arg_* should already have been updated. However, we found that this does not
# work under Tensorflow when the index type is int32 (arg_u and arg_e would be
# all zero).
# The workaround below was proposed by Jinjing; we still need to investigate
# where the problem is.
arg_u = None if arg_u is None else F.zerocopy_from_dgl_ndarray(arg_u_nd)
arg_e = None if arg_e is None else F.zerocopy_from_dgl_ndarray(arg_e_nd)
# To deal with scalar node/edge features.
if (expand_u or not use_u) and (expand_e or not use_e): if (expand_u or not use_u) and (expand_e or not use_e):
v = F.squeeze(v, -1) v = F.squeeze(v, -1)
return v, (arg_u, arg_e) return v, (arg_u, arg_e)
def _gsddmm(gidx, op, lhs, rhs, lhs_target='u', rhs_target='v'): def _gsddmm(gidx, op, lhs, rhs, lhs_target='u', rhs_target='v'):
r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface. It r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface. It
takes the result of :attr:`op` on source node feature and destination node takes the result of :attr:`op` on source node feature and destination node
...@@ -225,4 +243,5 @@ def _gsddmm(gidx, op, lhs, rhs, lhs_target='u', rhs_target='v'): ...@@ -225,4 +243,5 @@ def _gsddmm(gidx, op, lhs, rhs, lhs_target='u', rhs_target='v'):
out = F.squeeze(out, -1) out = F.squeeze(out, -1)
return out return out
_init_api("dgl.sparse") _init_api("dgl.sparse")
...@@ -70,15 +70,12 @@ udf_reduce = { ...@@ -70,15 +70,12 @@ udf_reduce = {
graphs = [ graphs = [
# dgl.rand_graph(30, 0), # dgl.rand_graph(30, 0),
dgl.rand_graph(100, 30), dgl.rand_graph(30, 100),
dgl.rand_graph(100, 3000), dgl.rand_bipartite(30, 40, 300)
dgl.rand_bipartite(80, 160, 3000)
] ]
spmm_shapes = [ spmm_shapes = [
((1, 2, 1, 3, 1), (4, 1, 3, 1, 1)), ((1, 2, 1, 3, 1), (4, 1, 3, 1, 1)),
((5, 3, 1, 7), (1, 3, 7, 1)),
((1, 3, 1), (4, 1, 3)),
((3, 3), (1, 3)), ((3, 3), (1, 3)),
((1,), (3,)), ((1,), (3,)),
((3,), (1,)), ((3,), (1,)),
...@@ -89,7 +86,6 @@ sddmm_shapes = [ ...@@ -89,7 +86,6 @@ sddmm_shapes = [
((1, 2, 1, 3, 1), (4, 1, 3, 1, 1)), ((1, 2, 1, 3, 1), (4, 1, 3, 1, 1)),
((5, 3, 1, 7), (1, 3, 7, 7)), ((5, 3, 1, 7), (1, 3, 7, 7)),
((1, 3, 3), (4, 1, 3)), ((1, 3, 3), (4, 1, 3)),
((3, 3), (1, 3)),
((3,), (3,)), ((3,), (3,)),
((1,), (1,)) ((1,), (1,))
] ]
...@@ -101,8 +97,6 @@ sddmm_shapes = [ ...@@ -101,8 +97,6 @@ sddmm_shapes = [
@parametrize_dtype @parametrize_dtype
def test_spmm(idtype, g, shp, msg, reducer): def test_spmm(idtype, g, shp, msg, reducer):
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
if dgl.backend.backend_name == 'tensorflow' and (reducer in ['min', 'max']):
pytest.skip() # tensorflow dlpack has problem writing into int32 arrays on GPU.
print(g) print(g)
print(g.idtype) print(g.idtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment