Unverified commit f25b1a06, authored by Zihao Ye and committed by GitHub

[Feature] Autograd of gspmm and gsddmm on PyTorch/MXNet/Tensorflow (#1680)

* init

* reverse (by minjie)

* gpu

* remove redundancy

* cache

* trigger

* unused import

* misc fixes and updates
parent c13903bf
......@@ -30,4 +30,3 @@ from .traversal import *
from .transform import *
from .propagate import *
from .udf import NodeBatch, EdgeBatch
from .sparse import gspmm, gsddmm
......@@ -7,6 +7,7 @@ import importlib
from . import backend
from .set_default_backend import set_default_backend
from itertools import product
_enabled_apis = set()
......@@ -18,6 +19,189 @@ def _gen_missing_api(api, mod_name):
' the DGLBACKEND environment.' % (api, mod_name))
return _missing_api
_notes_docstring = r"""
Notes
-----
This function supports autograd (computing input gradients given the output gradient). If the
feature shapes of the two input operands do not match, we first broadcast the features to a
unified shape (note that memory usage does not increase accordingly) and then perform the
operation.
Broadcasting follows NumPy semantics. Please see
https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
for more details about the NumPy broadcasting semantics."""
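# Illustrative broadcasting sketch (hypothetical shapes, not taken from the code in
# this module): with source node features of shape (N, 3, 1) and edge features of
# shape (E, 1, 4), an operator such as u_mul_e produces edge features of shape
# (E, 3, 4); the trailing (feature) dimensions follow NumPy broadcasting rules.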
def _gen_sddmm_func(lhs_target, rhs_target, binary_op):
name = "{}_{}_{}".format(lhs_target, binary_op, rhs_target)
target_dict = {
'u': "source node",
'e': "edge",
'v': "destination node"
}
lhs_str = target_dict[lhs_target]
rhs_str = target_dict[rhs_target]
docstring = r"""Generalized SDDMM function.
It computes edge features by applying the ``{}`` operator to the {} features and the {} features.
Parameters
----------
g : DGLHeteroGraph
The input graph
x : tensor
The {} features.
y : tensor
The {} features.
Returns
-------
tensor
The result tensor.
{}""".format(binary_op, lhs_str, rhs_str,
lhs_str, rhs_str,
_notes_docstring)
def func(g, x, y):
return gsddmm(g, binary_op, x, y,
lhs_target=lhs_target, rhs_target=rhs_target)
func.__name__ = name
func.__doc__ = docstring
return func
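# A minimal usage sketch of one generated wrapper (assumes a DGLHeteroGraph `g`
# with a single edge type; `x` and `y` are hypothetical node feature tensors):
#
#     h = u_dot_v(g, x, y)
#     # equivalent to: gsddmm(g, 'dot', x, y, lhs_target='u', rhs_target='v')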
def _gen_spmm_func(binary_op, reduce_op):
name = "u_{}_e_{}".format(binary_op, reduce_op)
docstring = """Generalized SpMM function.
It fuses two steps into one kernel.
(1) Computes messages by applying ``{}`` to source node and edge features.
(2) Aggregates the messages by ``{}`` as the features on destination nodes.
Parameters
----------
g : DGLHeteroGraph
The input graph
x : tensor
The source node features.
y : tensor
The edge features.
Returns
-------
tensor
The result tensor.
{}""".format(binary_op, reduce_op,
_notes_docstring)
def func(g, x, y):
return gspmm(g, binary_op, reduce_op, x, y)
func.__name__ = name
func.__doc__ = docstring
return func
def _gen_copy_reduce_func(binary_op, reduce_op):
name = "{}_{}".format(binary_op, reduce_op)
binary_str = {
"copy_u": "It copies node feature to edge as the message.",
'copy_e': "It regards edge feature as message."
}
x_str = {
"copy_u": "source node",
"copy_e": "edge"
}
docstring = lambda binary_op: """Generalized SpMM function. {}
Then aggregates the message by {} on destination nodes.
Parameters
----------
g : DGLHeteroGraph
The input graph
x : tensor
The {} features.
Returns
-------
tensor
The result tensor.
{}""".format(
binary_str[binary_op],
reduce_op,
x_str[binary_op],
_notes_docstring)
def func(g, x):
if binary_op == 'copy_u':
return gspmm(g, 'copy_lhs', reduce_op, x, None)
else:
return gspmm(g, 'copy_rhs', reduce_op, None, x)
func.__name__ = name
func.__doc__ = docstring(binary_op)
return func
def _register_sddmm_func(mod, enabled_apis):
"""Register sddmm functions"""
target = ["u", "v", "e"]
for lhs, rhs in product(target, target):
if lhs != rhs:
for binary_op in ["add", "sub", "mul", "div", "dot"]:
func = _gen_sddmm_func(lhs, rhs, binary_op)
setattr(mod, func.__name__, func)
enabled_apis.add(func.__name__)
def _register_spmm_func(mod, enabled_apis):
"""Register spmm functions"""
for binary_op in ["add", "sub", "mul", "div", "copy_u", "copy_e"]:
for reduce_op in ["sum", "max", "min"]:
if binary_op.startswith("copy"):
func = _gen_copy_reduce_func(binary_op, reduce_op)
else:
func = _gen_spmm_func(binary_op, reduce_op)
setattr(mod, func.__name__, func)
enabled_apis.add(func.__name__)
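# After registration, the loaded backend module exposes wrappers following the
# naming patterns above, e.g. u_add_v, u_dot_e, u_mul_e_sum, copy_u_max and
# copy_e_sum; these names illustrate the generated pattern rather than list it
# exhaustively.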
def copy_u(g, x):
r"""Generalized SDDMM function that copies source node features to edges.
Parameters
----------
g : DGLHeteroGraph
The input graph.
x : tensor
The source node features.
Returns
-------
tensor
The result tensor.
Notes
-----
This function supports autograd (computing input gradients given the output gradient).
""".format(_notes_docstring)
return gsddmm(g, 'copy_lhs', x, None)
def copy_v(g, x):
r"""Generalized SDDMM function that copies destination node features to edges.
Parameters
----------
g : DGLHeteroGraph
The input graph.
x : tensor
The destination node features.
Returns
-------
tensor
The result tensor.
Notes
-----
This function supports autograd (computing input gradients given the output gradient).
""".format(_notes_docstring)
return gsddmm(g, 'copy_rhs', None, x)
def load_backend(mod_name):
print('Using backend: %s' % mod_name, file=sys.stderr)
......@@ -50,6 +234,12 @@ def load_backend(mod_name):
setattr(thismod, api, mod.__dict__[api])
else:
setattr(thismod, api, _gen_missing_api(api, mod_name))
_register_sddmm_func(thismod, _enabled_apis)
_register_spmm_func(thismod, _enabled_apis)
setattr(thismod, copy_u.__name__, copy_u)
_enabled_apis.add(copy_u.__name__)
setattr(thismod, copy_v.__name__, copy_v)
_enabled_apis.add(copy_v.__name__)
def get_preferred_backend():
......
......@@ -1263,6 +1263,21 @@ def zerocopy_to_dgl_ndarray(input):
"""
pass
def zerocopy_to_dgl_ndarray_for_write(input):
"""Zerocopy a framework-specific Tensor to dgl.ndarray.NDArray
that is ready for write (required in MXNet).
Parameters
----------
input : Tensor
Returns
-------
dgl.ndarray.NDArray
"""
pass
def zerocopy_from_dgl_ndarray(input):
"""Zerocopy a dgl.ndarray.NDArray to framework-specific Tensor
......@@ -1351,6 +1366,79 @@ def copy_reduce(reducer, graph, target, in_data, out_size, in_map, out_map):
"""
pass
def gspmm(g, op, reduce_op, lhs_data, rhs_data):
r""" Generalized Sparse Matrix Multiplication interface.
It fuses two steps into one kernel.
(1) Computes messages by applying :attr:`op` to source node and edge features.
(2) Aggregates the messages by :attr:`reduce_op` as the features on destination nodes.
.. math::
x_v = \psi_{(u, v, e)\in \mathcal{G}}(\rho(x_u, x_e))
where :math:`x_v` is the returned feature on destination nodes, and :math:`x_u`,
:math:`x_e` refer to :attr:`u`, :attr:`e` respectively. :math:`\rho` is the binary
operator :attr:`op`, :math:`\psi` is the reduce operator :attr:`reduce_op`, and
:math:`\mathcal{G}` is the graph we apply gspmm on: :attr:`g`.
This function supports autograd: the framework-specific implementations compute input
gradients given the output gradient.
Parameters
----------
g : DGLHeteroGraph
The input graph.
op : str
The binary op's name, could be ``add``, ``sub``, ``mul``, ``div``,
``copy_lhs``, ``copy_rhs``.
reduce_op : str
Reduce operator, could be ``sum``, ``max``, ``min``.
lhs_data : tensor or None
The left operand, could be None if it's not required by the op.
rhs_data : tensor or None
The right operand, could be None if it's not required by the op.
Returns
-------
tensor
The result tensor.
"""
pass
def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface.
It computes edge features by applying :attr:`op` to the lhs and rhs features.
.. math::
x_{e} = \phi(x_{lhs}, x_{rhs}), \forall (u,e,v)\in \mathcal{G}
where :math:`x_{e}` is the returned feature on edges, and :math:`x_{lhs}`,
:math:`x_{rhs}` are the features of the operands selected by :attr:`lhs_target`
and :attr:`rhs_target`. :math:`\phi` is the binary operator :attr:`op`, and
:math:`\mathcal{G}` is the graph we apply gsddmm on: :attr:`g`. Each of
:math:`lhs` and :math:`rhs` is one of :math:`u`, :math:`e`, :math:`v`.
Parameters
----------
g : DGLHeteroGraph
The input graph.
op : str
Binary operator, could be ``add``, ``sub``, ``mul``, ``div``, ``dot``,
``copy_lhs``, ``copy_rhs``.
lhs_data : tensor or None
The left operand, could be None if it's not required by op.
rhs_data : tensor or None
The right operand, could be None if it's not required by op.
lhs_target : str
Choice of ``u`` (source), ``e`` (edge) or ``v`` (destination) for the left operand.
rhs_target : str
Choice of ``u`` (source), ``e`` (edge) or ``v`` (destination) for the right operand.
Returns
-------
tensor
The result tensor.
"""
pass
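# Combined usage sketch of the two interfaces above (hypothetical tensors `x` on
# source nodes and `y` on edges of a single-etype graph `g`):
#
#     h = gspmm(g, 'mul', 'sum', x, y)      # h[v] = sum over (u, e, v) of x[u] * y[e]
#     s = gsddmm(g, 'dot', x, h, 'u', 'v')  # s[e] = dot(x[u], h[v]) for each edge (u, v)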
###############################################################################
# Other interfaces
# ----------------
......
from .tensor import *
from .sparse import *
import mxnet as mx
import numpy as np
from mxnet import nd
from ...sparse import _gspmm, _gsddmm
from ...base import dgl_warning
from .tensor import asnumpy, copy_to, zerocopy_from_numpy, context
def _scatter_nd(index, src, n_rows):
assert index.shape == src.shape
dgl_warning("MXNet do not support scatter_add, fallback to numpy.")
ctx = context(src)
index = asnumpy(index)
src = asnumpy(src)
shp = index.shape
ndim = src.ndim
offsets = []
stride = 1
for i in reversed(range(1, ndim)):
di = shp[i]
offset_i = np.arange(di, dtype=index.dtype)
offsets.append(
(stride * offset_i).reshape((1,) * i + (di,) + (1,) * (ndim - 1 - i)))
stride *= di
new_idx = index * stride + sum(offsets)
src = src.reshape(-1)
new_idx = new_idx.reshape(-1)
rst = np.zeros((stride * n_rows,), dtype=src.dtype)
np.add.at(rst, new_idx, src)
rst = rst.reshape(n_rows, *shp[1:])
rst = copy_to(zerocopy_from_numpy(rst), ctx)
return rst
def _gather_nd(index, src):
ctx = context(src)
shp = index.shape
ndim = src.ndim
offsets = []
stride = 1
for i in reversed(range(1, ndim)):
di = shp[i]
offset_i = nd.arange(di, dtype=index.dtype)
offsets.append(
(stride * offset_i).reshape((1,) * i + (di,) + (1,) * (ndim - 1 - i)))
stride *= di
new_idx = index * stride + copy_to(sum(offsets), ctx)
src = src.reshape(-1)
new_idx = new_idx.reshape(-1)
rst = nd.take(src, new_idx).reshape(shp)
return rst
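# Index convention used below (2-D sketch, hypothetical shapes): with index of
# shape (R, D) holding row ids from a max/min reduction and src with D columns,
# _gather_nd(index, src) returns t with t[i, d] = src[index[i, d], d], while
# _scatter_nd(index, src, n_rows) accumulates src[i, d] into row index[i, d],
# column d of a zero tensor with n_rows rows.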
def _reduce_grad(grad, shape):
"""Reduce gradient on the broadcast dimension
If there is broadcasting in the forward pass, gradients need to be reduced on
the broadcast dimension. This function checks the input tensor shape and
gradient shape and performs the reduction.
Parameters
----------
grad: Tensor
Gradient tensor
shape: tuple
Shape of input tensor
Returns
-------
Tensor
"""
grad_shape = grad.shape[1:]
in_shape = shape[1:]
if in_shape == grad_shape:
# no need to reduce
return grad
num_to_squeeze = len(grad_shape) - len(in_shape)
# pad in_shape
in_shape = (1,) * num_to_squeeze + in_shape
reduce_idx = np.nonzero(np.asarray(grad_shape) - np.asarray(in_shape))[0]
reduce_idx += 1 # skip batch dim
grad = grad.sum(axis=tuple(reduce_idx), keepdims=True)
return grad.reshape(shape)
def _muldiv(op, x):
return 1. / x if op == 'div' else x
def _addsub(op, x):
return -x if op == 'sub' else x
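# These two helpers fold the chain rule of the binary op into a single factor:
# d(x * y)/dx = y while d(x / y)/dx = 1 / y, hence _muldiv; d(x + y)/dy = 1 while
# d(x - y)/dy = -1, hence _addsub.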
class GSpMM(mx.autograd.Function):
def __init__(self, g, op, reduce_op):
super(GSpMM, self).__init__()
self.gidx = g._graph
self.op = op
self.reduce_op = reduce_op
def forward(self, X, Y):
out, (argX, argY) = _gspmm(self.gidx, self.op, self.reduce_op, X, Y)
self.save_for_backward(X, Y, argX, argY)
return out
def backward(self, dZ):
ctx = context(dZ)
X, Y, argX, argY = self.saved_tensors
gidx, op, reduce_op = self.gidx, self.op, self.reduce_op
dX, dY = nd.empty((), ctx=ctx), nd.empty((), ctx=ctx)
if op != 'copy_rhs' and X.grad is not None:
g_rev = gidx.reverse()
if reduce_op == 'sum':
if op in ['mul', 'div']:
dX = _gspmm(g_rev, 'mul', 'sum', dZ, _muldiv(op, Y))[0]
elif op in ['add', 'sub']:
dX = _gspmm(g_rev, 'copy_lhs', 'sum', dZ, Y)[0]
elif op == 'copy_lhs':
dX = _gspmm(g_rev, 'copy_lhs', 'sum', dZ, None)[0]
else:
if op in ['mul', 'div']:
dX = _scatter_nd(
argX,
_muldiv(op, _gather_nd(argY, Y.broadcast_to((Y.shape[0], *dZ.shape[1:])))) * dZ,
X.shape[0])
elif op in ['add', 'sub', 'copy_lhs']:
dX = _scatter_nd(argX, dZ, X.shape[0])
dX = _reduce_grad(dX, X.shape)
if op != 'copy_lhs' and Y.grad is not None:
if reduce_op == 'sum':
if op in ['mul', 'div']:
dY = _gsddmm(gidx, 'mul', X, dZ)
if op == 'div': dY = -dY / (Y ** 2)
elif op in ['add', 'sub', 'copy_rhs']:
dY = _gsddmm(gidx, 'copy_rhs', X, _addsub(op, dZ))
else:
if op in ['mul', 'div']:
dY = _scatter_nd(
argY,
_gather_nd(argX, X.broadcast_to((X.shape[0], *dZ.shape[1:]))) * dZ,
Y.shape[0])
if op == 'div': dY = -dY / (Y ** 2)
elif op in ['add', 'sub', 'copy_rhs']:
dY = _scatter_nd(argY, _addsub(op, dZ), Y.shape[0])
dY = _reduce_grad(dY, Y.shape)
self.saved_tensors = None
return dX, dY
def gspmm(g, op, reduce_op, lhs_data, rhs_data):
func = GSpMM(g, op, reduce_op)
return func(lhs_data, rhs_data)
class GSDDMM(mx.autograd.Function):
def __init__(self, g, op, lhs_target, rhs_target):
super(GSDDMM, self).__init__()
self.gidx = g._graph
self.op = op
self.lhs_target = lhs_target
self.rhs_target = rhs_target
def forward(self, X, Y):
out = _gsddmm(self.gidx, self.op, X, Y, self.lhs_target, self.rhs_target)
self.save_for_backward(X, Y)
return out
def backward(self, dZ):
ctx = context(dZ)
X, Y = self.saved_tensors
gidx, op = self.gidx, self.op
lhs_target, rhs_target = self.lhs_target, self.rhs_target
dX, dY = nd.empty((), ctx=ctx), nd.empty((), ctx=ctx)
if op != 'copy_rhs':
if lhs_target in ['u', 'v']:
_gidx = gidx if self.lhs_target == 'v' else gidx.reverse()
if op in ['add', 'sub', 'copy_lhs']:
dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0]
else: # mul, div, dot
if rhs_target == lhs_target:
dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0] * _muldiv(op, Y)
elif self.rhs_target == 'e':
dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * _muldiv(op, Y))[0]
else: # rhs_target = !lhs_target
dX = _gspmm(_gidx, 'mul', 'sum', _muldiv(op, Y), dZ)[0]
else: # lhs_target == 'e'
if op in ['add', 'sub', 'copy_lhs']:
dX = dZ
else: # mul, div, dot
dX = _gsddmm(gidx, 'mul', dZ, _muldiv(op, Y), 'e', rhs_target)
dX = _reduce_grad(dX, X.shape)
if op != 'copy_lhs':
if self.rhs_target in ['u', 'v']:
_gidx = gidx if rhs_target == 'v' else gidx.reverse()
if op in ['add', 'sub', 'copy_rhs']:
dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, _addsub(op, dZ))[0]
else: # mul, div, dot
if lhs_target == rhs_target:
dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0] * X
elif self.lhs_target == 'e':
dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * X)[0]
else: # rhs_target = !lhs_target
dY = _gspmm(_gidx, 'mul', 'sum', X, dZ)[0]
if op == 'div': dY = -dY / (Y ** 2)
else:
if op in ['add', 'sub', 'copy_rhs']:
dY = _addsub(op, dZ)
else: # mul, div, dot
dY = _gsddmm(gidx, 'mul', dZ, X, 'e', lhs_target)
if op == 'div': dY = -dY / (Y ** 2)
dY = _reduce_grad(dY, Y.shape)
self.saved_tensors = None
return dX, dY
def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
func = GSDDMM(g, op, lhs_target, rhs_target)
return func(lhs_data, rhs_data)
......@@ -10,7 +10,7 @@ import numbers
import builtins
from ... import ndarray as dglnd
from ... import kernel as K
from ...function.base import TargetCode
MX_VERSION = LooseVersion(mx.__version__)
if MX_VERSION.version[0] == 1 and MX_VERSION.version[1] < 5:
......
from .tensor import *
from .sparse import *
import torch as th
from ...sparse import _gspmm, _gsddmm
__all__ = ['gspmm', 'gsddmm']
def _reduce_grad(grad, shape):
"""Reduce gradient on the broadcast dimension
If there is broadcasting in the forward pass, gradients need to be reduced on
the broadcast dimension. This function checks the input tensor shape and
gradient shape and performs the reduction.
Parameters
----------
grad: Tensor
Gradient tensor
shape: tuple
Shape of input tensor
Returns
-------
Tensor
"""
grad_shape = grad.shape[1:]
in_shape = shape[1:]
if in_shape == grad_shape:
# no need to reduce
return grad
num_to_squeeze = len(grad_shape) - len(in_shape)
# pad inshape
in_shape = (1,) * num_to_squeeze + in_shape
reduce_idx = th.nonzero(th.tensor(grad_shape) - th.tensor(in_shape))
reduce_idx += 1 # skip batch dim
if len(reduce_idx) > 0:
grad = grad.sum(dim=tuple(reduce_idx), keepdim=True)
return grad.view(-1, *shape[1:])
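# Worked example (hypothetical shapes): if the forward input had shape (N, 1, 4)
# and was broadcast to feature shape (3, 4) in the forward pass, the incoming
# gradient here has shape (N, 3, 4); it is summed over the broadcast axis (dim 1)
# with keepdim and reshaped back to (N, 1, 4) to match the input.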
def _muldiv(op, x):
return 1. / x if op == 'div' else x
def _addsub(op, x):
return -x if op == 'sub' else x
class GSpMM(th.autograd.Function):
@staticmethod
def forward(ctx, g, op, reduce_op, X, Y):
gidx = g._graph
out, (argX, argY) = _gspmm(gidx, op, reduce_op, X, Y)
ctx.backward_cache = gidx, op, reduce_op
ctx.save_for_backward(X, Y, argX, argY)
return out
@staticmethod
def backward(ctx, dZ):
gidx, op, reduce_op = ctx.backward_cache
X, Y, argX, argY = ctx.saved_tensors
dX, dY = None, None
if op != 'copy_rhs' and ctx.needs_input_grad[3]:
g_rev = gidx.reverse()
if reduce_op == 'sum':
if op in ['mul', 'div']:
dX = _gspmm(g_rev, 'mul', 'sum', dZ, _muldiv(op, Y))[0]
elif op in ['add', 'sub']:
dX = _gspmm(g_rev, 'copy_lhs', 'sum', dZ, Y)[0]
elif op == 'copy_lhs':
dX = _gspmm(g_rev, 'copy_lhs', 'sum', dZ, None)[0]
else:
dX = th.zeros((X.shape[0],) + dZ.shape[1:], dtype=X.dtype, device=X.device)
if op in ['mul', 'div']:
dX.scatter_add_(0, argX.long(),
_muldiv(op, Y.expand(-1, *dZ.shape[1:]).gather(0, argY.long())) * dZ)
elif op in ['add', 'sub', 'copy_lhs']:
dX.scatter_add_(0, argX.long(), dZ)
dX = _reduce_grad(dX, X.shape)
if op != 'copy_lhs' and ctx.needs_input_grad[4]:
if reduce_op == 'sum':
if op in ['mul', 'div']:
dY = _gsddmm(gidx, 'mul', X, dZ)
if op == 'div': dY = -dY / (Y ** 2)
elif op in ['add', 'sub', 'copy_rhs']:
dY = _gsddmm(gidx, 'copy_rhs', X, _addsub(op, dZ))
else:
dY = th.zeros((Y.shape[0],) + dZ.shape[1:], dtype=Y.dtype, device=Y.device)
if op in ['mul', 'div']:
dY.scatter_add_(0, argY.long(),
X.expand(-1, *dZ.shape[1:]).gather(0, argX.long()) * dZ)
if op == 'div': dY = -dY / (Y ** 2)
elif op in ['add', 'sub', 'copy_rhs']:
dY.scatter_add_(0, argY.long(), _addsub(op, dZ))
dY = _reduce_grad(dY, Y.shape)
return None, None, None, dX, dY
class GSDDMM(th.autograd.Function):
@staticmethod
def forward(ctx, g, op, X, Y, lhs_target, rhs_target):
gidx = g._graph
out = _gsddmm(gidx, op, X, Y, lhs_target, rhs_target)
ctx.backward_cache = gidx, op, lhs_target, rhs_target
ctx.save_for_backward(X, Y)
return out
@staticmethod
def backward(ctx, dZ):
gidx, op, lhs_target, rhs_target = ctx.backward_cache
X, Y = ctx.saved_tensors
dX, dY = None, None
if op != 'copy_rhs' and ctx.needs_input_grad[2]:
if lhs_target in ['u', 'v']:
_gidx = gidx if lhs_target == 'v' else gidx.reverse()
if op in ['add', 'sub', 'copy_lhs']:
dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0]
else: # mul, div, dot
if rhs_target == lhs_target:
dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0] * _muldiv(op, Y)
elif rhs_target == 'e':
dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * _muldiv(op, Y))[0]
else: # rhs_target = !lhs_target
dX = _gspmm(_gidx, 'mul', 'sum', _muldiv(op, Y), dZ)[0]
else: # lhs_target == 'e'
if op in ['add', 'sub', 'copy_lhs']:
dX = dZ
else: # mul, div, dot
dX = _gsddmm(gidx, 'mul', dZ, _muldiv(op, Y), 'e', rhs_target)
dX = _reduce_grad(dX, X.shape)
if op != 'copy_lhs' and ctx.needs_input_grad[3]:
if rhs_target in ['u', 'v']:
_gidx = gidx if rhs_target == 'v' else gidx.reverse()
if op in ['add', 'sub', 'copy_rhs']:
dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, _addsub(op, dZ))[0]
else: # mul, div, dot
if lhs_target == rhs_target:
dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0] * X
elif lhs_target == 'e':
dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * X)[0]
else: # rhs_target = !lhs_target
dY = _gspmm(_gidx, 'mul', 'sum', X, dZ)[0]
if op == 'div': dY = -dY / (Y ** 2)
else:
if op in ['add', 'sub', 'copy_rhs']:
dY = _addsub(op, dZ)
else: # mul, div, dot
dY = _gsddmm(gidx, 'mul', dZ, X, 'e', lhs_target)
if op == 'div': dY = -dY / (Y ** 2)
dY = _reduce_grad(dY, Y.shape)
return None, None, dX, dY, None, None
def gspmm(g, op, reduce_op, lhs_data, rhs_data):
return GSpMM.apply(g, op, reduce_op, lhs_data, rhs_data)
def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
return GSDDMM.apply(g, op, lhs_data, rhs_data, lhs_target, rhs_target)
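# A minimal autograd sketch for this backend (assumes a single-etype
# DGLHeteroGraph `g`; shapes are hypothetical):
#
#     import torch as th
#     x = th.randn(g.number_of_src_nodes(), 4, requires_grad=True)
#     e = th.randn(g.number_of_edges(), 4, requires_grad=True)
#     out = gspmm(g, 'mul', 'sum', x, e)
#     out.sum().backward()    # populates x.grad and e.grad via GSpMM.backward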
......@@ -309,6 +309,9 @@ def zerocopy_from_numpy(np_array):
def zerocopy_to_dgl_ndarray(input):
return nd.from_dlpack(dlpack.to_dlpack(input.contiguous()))
def zerocopy_to_dgl_ndarray_for_write(input):
return zerocopy_to_dgl_ndarray(input)
def zerocopy_from_dgl_ndarray(input):
return dlpack.from_dlpack(input.to_dlpack())
......
......@@ -2,3 +2,4 @@ import os
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
from .tensor import *
from .sparse import *
import tensorflow as tf
import numpy as np
from .tensor import tensor, copy_to, context
from ...sparse import _gspmm, _gsddmm
def _scatter_nd(index, src, n_rows):
assert index.shape == src.shape
shp = index.shape
ctx = context(src)
ndim = index.ndim
offsets = []
stride = 1
for i in reversed(range(1, ndim)):
di = shp[i]
offset_i = tf.range(di, dtype=index.dtype)
offsets.append(
tf.reshape((stride * offset_i), (1,) * i + (di,) + (1,) * (ndim - 1 - i)))
stride *= di
new_idx = index * stride + copy_to(sum(offsets), ctx)
src = tf.reshape(src, (-1,))
new_idx = tf.reshape(new_idx, (-1, 1))
rst = tf.reshape(tf.scatter_nd(new_idx, src, (stride * n_rows,)), (n_rows, *shp[1:]))
return rst
def _gather_nd(index, src):
shp = index.shape
ctx = context(src)
ndim = index.ndim
offsets = []
stride = 1
for i in reversed(range(1, ndim)):
di = shp[i]
offset_i = tf.range(di, dtype=index.dtype)
offsets.append(
tf.reshape((stride * offset_i), (1,) * i + (di,) + (1,) * (ndim - 1 - i)))
stride *= di
new_idx = index * stride + copy_to(sum(offsets), ctx)
src = tf.reshape(src, (-1,))
new_idx = tf.reshape(new_idx, (-1,))
rst = tf.reshape(tf.gather(src, new_idx), shp)
return rst
def _reduce_grad(grad, shape):
"""Reduce gradient on the broadcast dimension
If there is broadcasting in the forward pass, gradients need to be reduced on
the broadcast dimension. This function checks the input tensor shape and
gradient shape and performs the reduction.
Parameters
----------
grad: Tensor
Gradient tensor
shape: tuple
Shape of input tensor
Returns
-------
Tensor
"""
grad_shape = grad.shape[1:]
in_shape = shape[1:]
if in_shape == grad_shape:
# no need to reduce
return grad
num_to_squeeze = len(grad_shape) - len(in_shape)
# pad inshape
in_shape = (1,) * num_to_squeeze + in_shape
reduce_idx = np.asarray(np.nonzero(np.asarray(grad_shape) - np.asarray(in_shape)))
reduce_idx += 1 # skip batch dim
reduce_idx_tensor = tf.constant(tuple(
reduce_idx.flatten().tolist()))
grad = tf.reduce_sum(grad, axis=reduce_idx_tensor, keepdims=True)
return tf.reshape(grad, shape)
def _muldiv(op, x):
return 1. / x if op == 'div' else x
def _addsub(op, x):
return -x if op == 'sub' else x
def gspmm_real(g, op, reduce_op, X, Y):
gidx = g._graph
out, (argX, argY) = _gspmm(gidx, op, reduce_op, X, Y)
def grad(dZ):
dZ = tensor(dZ)
dX, dY = tf.zeros(()), tf.zeros(())
if op != 'copy_rhs':
g_rev = gidx.reverse()
if reduce_op == 'sum':
if op in ['mul', 'div']:
dX = _gspmm(g_rev, 'mul', 'sum', dZ, _muldiv(op, Y))[0]
elif op in ['add', 'sub']:
dX = _gspmm(g_rev, 'copy_lhs', 'sum', dZ, Y)[0]
elif op == 'copy_lhs':
dX = _gspmm(g_rev, 'copy_lhs', 'sum', dZ, None)[0]
else:
if op in ['mul', 'div']:
dX = _scatter_nd(
argX,
_muldiv(op, _gather_nd(argY, tf.broadcast_to(Y, (Y.shape[0], *dZ.shape[1:])))) * dZ,
X.shape[0])
elif op in ['add', 'sub', 'copy_lhs']:
dX = _scatter_nd(argX, dZ, X.shape[0])
dX = _reduce_grad(dX, X.shape)
if op != 'copy_lhs':
if reduce_op == 'sum':
if op in ['mul', 'div']:
dY = _gsddmm(gidx, 'mul', X, dZ)
if op == 'div': dY = -dY / (Y ** 2)
elif op in ['add', 'sub', 'copy_rhs']:
dY = _gsddmm(gidx, 'copy_rhs', X, _addsub(op, dZ))
else:
out_shp = (Y.shape[0],) + dZ.shape[1:]
if op in ['mul', 'div']:
dY = _scatter_nd(
argY,
_gather_nd(argX, tf.broadcast_to(X, (X.shape[0], *dZ.shape[1:]))) * dZ,
Y.shape[0])
if op == 'div': dY = -dY / (Y ** 2)
elif op in ['add', 'sub', 'copy_rhs']:
dY = _scatter_nd(argY, _addsub(op, dZ), Y.shape[0])
dY = _reduce_grad(dY, Y.shape)
return dX, dY
return out, grad
def gspmm(g, op, reduce_op, X, Y):
@tf.custom_gradient
def _lambda(X, Y):
return gspmm_real(g, op, reduce_op, X, Y)
return _lambda(X, Y)
def gsddmm_real(g, op, X, Y, lhs_target, rhs_target):
gidx = g._graph
out = _gsddmm(gidx, op, X, Y, lhs_target, rhs_target)
def grad(dZ):
dX, dY = tf.zeros(()), tf.zeros(())
if op != 'copy_rhs':
if lhs_target in ['u', 'v']:
_gidx = gidx if lhs_target == 'v' else gidx.reverse()
if op in ['add', 'sub', 'copy_lhs']:
dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0]
else: # mul, div, dot
if rhs_target == lhs_target:
dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0] * _muldiv(op, Y)
elif rhs_target == 'e':
dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * _muldiv(op, Y))[0]
else: # rhs_target = !lhs_target
dX = _gspmm(_gidx, 'mul', 'sum', _muldiv(op, Y), dZ)[0]
else: # lhs_target == 'e'
if op in ['add', 'sub', 'copy_lhs']:
dX = dZ
else: # mul, div, dot
dX = _gsddmm(gidx, 'mul', dZ, _muldiv(op, Y), 'e', rhs_target)
dX = _reduce_grad(dX, X.shape)
if op != 'copy_lhs':
if rhs_target in ['u', 'v']:
_gidx = gidx if rhs_target == 'v' else gidx.reverse()
if op in ['add', 'sub', 'copy_rhs']:
dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, _addsub(op, dZ))[0]
else: # mul, div, dot
if lhs_target == rhs_target:
dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0] * X
elif lhs_target == 'e':
dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * X)[0]
else: # rhs_target = !lhs_target
dY = _gspmm(_gidx, 'mul', 'sum', X, dZ)[0]
if op == 'div': dY = -dY / (Y ** 2)
else:
if op in ['add', 'sub', 'copy_rhs']:
dY = _addsub(op, dZ)
else: # mul, div, dot
dY = _gsddmm(gidx, 'mul', dZ, X, 'e', lhs_target)
if op == 'div': dY = -dY / (Y ** 2)
dY = _reduce_grad(dY, Y.shape)
return dX, dY
return out, grad
def gsddmm(g, op, X, Y, lhs_target='u', rhs_target='v'):
@tf.custom_gradient
def _lambda(X, Y):
return gsddmm_real(g, op, X, Y, lhs_target, rhs_target)
return _lambda(X, Y)
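# A minimal gradient sketch for this backend (assumes a single-etype
# DGLHeteroGraph `g` and float32 tensors `x` on source nodes, `e` on edges):
#
#     with tf.GradientTape() as tape:
#         tape.watch([x, e])
#         out = gspmm(g, 'add', 'sum', x, e)
#         loss = tf.reduce_sum(out)
#     dx, de = tape.gradient(loss, [x, e])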
......@@ -417,6 +417,8 @@ def zerocopy_from_numpy(np_array):
def zerocopy_to_dgl_ndarray(input):
return nd.from_dlpack(zerocopy_to_dlpack(input))
def zerocopy_to_dgl_ndarray_for_write(input):
return zerocopy_to_dgl_ndarray(input)
def zerocopy_from_dgl_ndarray(input):
return zerocopy_from_dlpack(input.to_dlpack())
......
......@@ -991,21 +991,17 @@ class HeteroGraphIndex(ObjectBase):
"""
return _CAPI_DGLHeteroGetFormatGraph(self, restrict_format)
def reverse(self, metagraph):
@utils.cached_member(cache='_cache', prefix='reverse')
def reverse(self):
"""Reverse the heterogeneous graph adjacency
The node types and edge types are not changed
Parameters
----------
metagraph : GraphIndex
Meta-graph.
The node types and edge types are not changed.
Returns
-------
A new graph index.
"""
return _CAPI_DGLHeteroReverse(metagraph, self)
return _CAPI_DGLHeteroReverse(self)
@register_object('graph.HeteroSubgraph')
class HeteroSubgraphIndex(ObjectBase):
......
......@@ -2,12 +2,8 @@
# pylint: disable= no-member, arguments-differ
import torch as th
from ...function import TargetCode
from ...base import ALL, is_all
from ... import backend as F
from ... import utils
from ...graph import DGLGraph
from ...heterograph import DGLHeteroGraph
__all__ = ['edge_softmax']
......@@ -49,36 +45,12 @@ class EdgeSoftmax(th.autograd.Function):
if not is_all(eids):
g = g.edge_subgraph(eids.long())
n_nodes = g.number_of_dst_nodes()
n_edges = g.number_of_edges()
# TODO(BarclayII): this is a temporary fix of memory leakage in PyTorch
# in PR #1139. We should investigate further on what was actually happening
# when implementing EdgeSoftmax with message passing API instead of
# operators.
score_context = utils.to_dgl_context(score.device)
if isinstance(g, DGLGraph):
gidx = g._graph.get_immutable_gidx(score_context)
elif isinstance(g, DGLHeteroGraph):
assert g._graph.number_of_etypes() == 1, \
"EdgeSoftmax only support one edge type"
gidx = g._graph.get_unitgraph(0, score_context)
ctx.backward_cache = n_nodes, n_edges, gidx
#g.update_all(fn.copy_e('s', 'm'), fn.max('m', 'smax'))
smax = F.copy_reduce('max', gidx, TargetCode.EDGE, score, n_nodes)
#g.apply_edges(fn.e_sub_v('s', 'smax', 'out'))
out = F.binary_reduce(
'none', 'sub', gidx, TargetCode.EDGE, TargetCode.DST, score, smax, n_edges)
#g.edata['out'] = th.exp(g.edata['out'])
out = th.exp(out)
#g.update_all(fn.copy_e('out', 'm'), fn.sum('m', 'out_sum'))
out_sum = F.copy_reduce('sum', gidx, TargetCode.EDGE, out, n_nodes)
#g.apply_edges(fn.e_div_v('out', 'out_sum', 'out'))
out = F.binary_reduce(
'none', 'div', gidx, TargetCode.EDGE, TargetCode.DST, out, out_sum, n_edges)
score_max = F.copy_e_max(g, score)
score = th.exp(F.e_sub_v(g, score, score_max))
score_sum = F.copy_e_sum(g, score)
out = F.e_div_v(g, score, score_sum)
ctx.backward_cache = g
ctx.save_for_backward(out)
return out
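# The four calls above implement a numerically stable segment softmax over the
# incoming edges of each destination node:
#   out[e] = exp(score[e] - max_v(score)) / sum_v(exp(score - max_v(score)))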
......@@ -95,22 +67,14 @@ class EdgeSoftmax(th.autograd.Function):
out = dgl.EData(g, out)
sds = out * grad_out # type dgl.EData
sds_sum = sds.dst_sum() # type dgl.NData
grad_score = sds - sds * sds_sum # multiple expressions
grad_score = sds - out * sds_sum # multiple expressions
return grad_score.data
"""
n_nodes, n_edges, gidx = ctx.backward_cache
g = ctx.backward_cache
out, = ctx.saved_tensors
#g.edata['grad_s'] = out * grad_out
grad_s = out * grad_out
#g.update_all(fn.copy_e('grad_s', 'm'), fn.sum('m', 'accum'))
accum = F.copy_reduce('sum', gidx, TargetCode.EDGE, grad_s, n_nodes)
#g.apply_edges(fn.e_mul_v('out', 'accum', 'out'))
out = F.binary_reduce(
'none', 'mul', gidx, TargetCode.EDGE, TargetCode.DST, out, accum, n_edges)
#grad_score = g.edata['grad_s'] - g.edata['out']
grad_score = grad_s - out
sds = out * grad_out
accum = F.copy_e_sum(g, sds)
grad_score = sds - F.e_mul_v(g, out, accum)
return None, grad_score, None
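# This is the standard softmax vector-Jacobian product: with s = softmax(score)
# per destination node, grad_score = s * grad_out - s * sum_dst(s * grad_out).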
......
"""Module for sparse matrix operators."""
# pylint: disable= invalid-name
from __future__ import absolute_import
import dgl.ndarray as nd
from ._ffi.function import _init_api
from .base import DGLError
......@@ -19,7 +21,7 @@ def infer_broadcast_shape(op, shp1, shp2):
Parameters
----------
op : str
The binary op's name, could be `add`, `sub`, `mul`, `div`, `dot`, `copy_u`, `copy_e`.
The binary op's name, could be `add`, `sub`, `mul`, `div`, `dot`, `copy_lhs`, `copy_rhs`.
shp1 : tuple[int]
The shape of lhs operand.
shp2 : tuple[int]
......@@ -36,9 +38,9 @@ def infer_broadcast_shape(op, shp1, shp2):
raise DGLError("Dot operator is only available for arrays with the "
"same size on last dimension, but got {} and {}."
.format(shp1, shp2))
if op == "copy_u":
if op == "copy_lhs":
return shp1
if op == "copy_e":
if op == "copy_rhs":
return shp2
# operands are padded to have the same dimensionality with leading 1's.
if len(shp1) > len(shp2):
......@@ -56,23 +58,20 @@ def to_dgl_nd(x):
"""Convert framework-specific tensor/None to dgl ndarray."""
return nd.NULL['int64'] if x is None else F.zerocopy_to_dgl_ndarray(x)
# map alias of operator name to its actually name that backend could recognize.
op_mapping = {
'+': 'add',
'-': 'sub',
'*': 'mul',
'/': 'div',
'.': 'dot',
'add': 'add',
'sub': 'sub',
'mul': 'mul',
'div': 'div',
'dot': 'dot',
'copy_u': 'copy_u',
'copy_e': 'copy_e'
def to_dgl_nd_for_write(x):
"""Convert framework-specific tensor/None to dgl ndarray for write."""
return nd.NULL['int64'] if x is None else F.zerocopy_to_dgl_ndarray_for_write(x)
target_mapping = {
'u': 0,
'e': 1,
'v': 2,
'src': 0,
'edge': 1,
'dst': 2
}
def gspmm(g, op, reduce_op, u, e):
def _gspmm(gidx, op, reduce_op, u, e):
r""" Generalized Sparse Matrix Multiplication interface. It takes the result of
:attr:`op` on source node feature and edge feature, leads to a message on edge.
Then aggregates the message by :attr:`reduce_op` on destination nodes.
......@@ -89,22 +88,25 @@ def gspmm(g, op, reduce_op, u, e):
Parameters
----------
g : DGLHeteroGraph
The input graph.
gidx : HeteroGraphIndex
The input graph index.
op : str
The binary op's name, could be ``add``, ``sub``, ``mul``, ``div``, ``dot``, ``copy_u``,
``copy_e``, or their alias ``+``, ``-``, ``*``, ``/``, ``.``.
The binary op's name, could be ``add``, ``sub``, ``mul``, ``div``, ``copy_lhs``,
``copy_rhs``.
reduce_op : str
Reduce operator, could be ``sum``, ``max``, ``min``.
u : tensor or None
The feature on source nodes, could be None if op is ``copy_e``.
The feature on source nodes, could be None if op is ``copy_rhs``.
e : tensor or None
The feature on edges, could be None if op is ``copy_u``.
The feature on edges, could be None if op is ``copy_lhs``.
Returns
-------
tensor
The result tensor.
tuple
The returned tuple is composed of two elements:
- The first element refers to the result tensor.
- The second element refers to a tuple composed of arg_u and arg_e
(which is useful when reducer is `min`/`max`).
Notes
-----
......@@ -112,40 +114,53 @@ def gspmm(g, op, reduce_op, u, e):
we expand its dimension with an additional dimension of length one. (e.g.
(90,) to (90, 1) for a graph with 90 nodes/edges).
"""
if u is not None:
if gidx.number_of_etypes() != 1:
raise DGLError("We only support gsddmm on graph with one edge type")
use_u = op != 'copy_rhs'
use_e = op != 'copy_lhs'
if use_u:
if F.ndim(u) == 1:
u = F.unsqueeze(u, -1)
if e is not None:
if use_e:
if F.ndim(e) == 1:
e = F.unsqueeze(e, -1)
if gidx.number_of_etypes() != 1:
raise DGLError("We only support gspmm on graph with one edge type")
op = op_mapping[op]
ctx = F.context(u) if u is not None else F.context(e)
dtype = F.dtype(u) if u is not None else F.dtype(e)
use_u = (op != 'copy_e')
use_e = (op != 'copy_u')
ctx = F.context(u) if use_u else F.context(e)
dtype = F.dtype(u) if use_u else F.dtype(e)
u_shp = F.shape(u) if use_u else (0,)
e_shp = F.shape(e) if use_e else (0,)
v_shp = (g.number_of_dst_nodes(), ) +\
_, dsttype = gidx.metagraph.find_edge(0)
v_shp = (gidx.number_of_nodes(dsttype), ) +\
infer_broadcast_shape(op, u_shp[1:], e_shp[1:])
v = F.zeros(v_shp, dtype, ctx)
use_cmp = reduce_op in ['max', 'min']
arg_u = F.zeros(v_shp, g.idtype, ctx) if use_cmp and use_u else None
arg_e = F.zeros(v_shp, g.idtype, ctx) if use_cmp and use_e else None
if g.number_of_edges() > 0:
gidx = g._graph.get_unitgraph(0, to_dgl_context(ctx))
_CAPI_DGLKernelSpMM(gidx, op, reduce_op,
to_dgl_nd(u), to_dgl_nd(e), to_dgl_nd(v),
to_dgl_nd(arg_u), to_dgl_nd(arg_e))
arg_u, arg_e = None, None
ugi = gidx.get_unitgraph(0, to_dgl_context(ctx))
idtype = getattr(F, ugi.dtype)
if use_cmp:
if use_u:
arg_u = F.zeros(v_shp, idtype, ctx)
if use_e:
arg_e = F.zeros(v_shp, idtype, ctx)
if gidx.number_of_edges(0) > 0:
_CAPI_DGLKernelSpMM(ugi, op, reduce_op,
to_dgl_nd(u if use_u else None),
to_dgl_nd(e if use_e else None),
to_dgl_nd_for_write(v),
to_dgl_nd_for_write(arg_u),
to_dgl_nd_for_write(arg_e))
return v, (arg_u, arg_e)
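# arg_u and arg_e record, for each output position, which source node / edge was
# selected by a max or min reduction; the framework backends use these indices to
# scatter the output gradient back to the winning entries during backward.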
def gsddmm(g, op, u, v):
def _gsddmm(gidx, op, lhs, rhs, lhs_target='u', rhs_target='v'):
r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface. It
takes the result of :attr:`op` on the lhs and rhs operand features (each chosen from
source node, edge, or destination node features via :attr:`lhs_target` and
:attr:`rhs_target`) and produces a feature on each edge.
.. math::
x_{e} = \phi(x_u, x_v), \forall (u,e,v)\in \mathcal{G}
x_{e} = \phi(x_u, x_e, x_v), \forall (u,e,v)\in \mathcal{G}
where :math:`x_{e}` is the returned feature on edges, and :math:`x_u`, :math:`x_e`,
:math:`x_v` refer to :attr:`u`, :attr:`e`, :attr:`v` respectively. :math:`\phi`
......@@ -154,15 +169,21 @@ def gsddmm(g, op, u, v):
Parameters
----------
g : DGLHeteroGraph
The input graph.
gidx : HeteroGraphIndex
The input graph index.
op : str
Binary operator, could be ``add``, ``sub``, ``mul``, ``div``, ``dot``, ``copy_u``,
or their alias ``+``, ``-``, ``*``, ``/``, ``.``.
u : tensor or None
The feature on source nodes.
v : tensor or None
The feature on destination, could be None if op is ``copy_u``.
Binary operator, could be ``add``, ``sub``, ``mul``, ``div``, ``dot``,
``copy_lhs``, ``copy_rhs``.
lhs : tensor or None
Left hand operand.
rhs : tensor or None
Right hand operand.
lhs_target : str
The target of left hand operand, could be ``src``, ``edge``, ``dst``
or their alias ``u``, ``e``, ``v``.
rhs_target : str
The target of right hand operand, could be ``src``, ``edge``, ``dst``
or their alias ``u``, ``e``, ``v``.
Returns
-------
......@@ -175,24 +196,33 @@ def gsddmm(g, op, u, v):
we expand its dimension with an additional dimension of length one. (e.g.
(90,) to (90, 1) for a graph with 90 nodes/edges).
"""
if u is not None:
if F.ndim(u) == 1:
u = F.unsqueeze(u, -1)
if v is not None:
if F.ndim(v) == 1:
v = F.unsqueeze(v, -1)
op = op_mapping[op]
ctx = F.context(u)
dtype = F.dtype(u)
u_shp = F.shape(u)
v_shp = F.shape(v) if v is not None else (0,)
e_shp = (g.number_of_edges(), ) +\
infer_broadcast_shape(op, u_shp[1:], v_shp[1:])
e = F.zeros(e_shp, dtype, ctx)
if g.number_of_edges() > 0:
gidx = g._graph.get_unitgraph(0, to_dgl_context(ctx))
_CAPI_DGLKernelSDDMM(gidx, op, to_dgl_nd(u), to_dgl_nd(v), to_dgl_nd(e))
return e
if gidx.number_of_etypes() != 1:
raise DGLError("We only support gsddmm on graph with one edge type")
use_lhs = op != 'copy_rhs'
use_rhs = op != 'copy_lhs'
if use_lhs:
if F.ndim(lhs) == 1:
lhs = F.unsqueeze(lhs, -1)
if use_rhs:
if F.ndim(rhs) == 1:
rhs = F.unsqueeze(rhs, -1)
lhs_target = target_mapping[lhs_target]
rhs_target = target_mapping[rhs_target]
ctx = F.context(lhs) if use_lhs else F.context(rhs)
dtype = F.dtype(lhs) if use_lhs else F.dtype(rhs)
lhs_shp = F.shape(lhs) if use_lhs else (0,)
rhs_shp = F.shape(rhs) if use_rhs else (0,)
out_shp = (gidx.number_of_edges(0), ) +\
infer_broadcast_shape(op, lhs_shp[1:], rhs_shp[1:])
out = F.zeros(out_shp, dtype, ctx)
if gidx.number_of_edges(0) > 0:
ugi = gidx.get_unitgraph(0, to_dgl_context(ctx))
_CAPI_DGLKernelSDDMM(ugi, op,
to_dgl_nd(lhs if use_lhs else None),
to_dgl_nd(rhs if use_rhs else None),
to_dgl_nd_for_write(out),
lhs_target, rhs_target)
return out
_init_api("dgl.sparse")
......@@ -9,7 +9,7 @@ from .graph import DGLGraph
from .heterograph import DGLHeteroGraph
from . import ndarray as nd
from . import backend as F
from .graph_index import from_coo, from_edge_list
from .graph_index import from_coo
from .graph_index import _get_halo_subgraph_inner_node
from .graph import unbatch
from .convert import graph, bipartite
......@@ -498,17 +498,8 @@ def reverse_heterograph(g, copy_ndata=True, copy_edata=False):
"""
# TODO(0.5 release, xiangsx) need to handle BLOCK
# currently reversing a block results in undefined behavior
canonical_etypes = g.canonical_etypes
meta_edges_src = []
meta_edges_dst = []
etypes = []
for c_etype in canonical_etypes:
meta_edges_src.append(g.get_ntype_id(c_etype[2]))
meta_edges_dst.append(g.get_ntype_id(c_etype[0]))
etypes.append(c_etype[1])
metagraph = from_edge_list((meta_edges_src, meta_edges_dst), True)
gidx = g._graph.reverse(metagraph)
new_g = DGLHeteroGraph(gidx, g.ntypes, etypes)
gidx = g._graph.reverse()
new_g = DGLHeteroGraph(gidx, g.ntypes, g.etypes)
# handle ndata
if copy_ndata:
......@@ -521,7 +512,7 @@ def reverse_heterograph(g, copy_ndata=True, copy_edata=False):
# handle edata
if copy_edata:
# for each etype
for etype in canonical_etypes:
for etype in g.etypes:
# for each data field
for k in g.edges[etype].data:
new_g.edges[etype].data[k] = g.edges[etype].data[k]
......
......@@ -9,57 +9,106 @@
namespace dgl {
namespace aten {
#define SWITCH_RHS(rhs_target, RhsTarget, ...) \
do { \
if ((rhs_target) == 0) { \
constexpr int RhsTarget = 0; \
{ __VA_ARGS__ } \
} else if ((rhs_target) == 1) { \
constexpr int RhsTarget = 1; \
{ __VA_ARGS__ } \
} else if ((rhs_target) == 2) { \
constexpr int RhsTarget = 2; \
{ __VA_ARGS__ } \
} else { \
LOG(INFO) << "Invalid rhs target: " << (rhs_target); \
} \
} while (0)
#define SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, ...)\
do { \
if ((lhs_target) == 0) { \
constexpr int LhsTarget = 0; \
SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__); \
} else if ((lhs_target) == 1) { \
constexpr int LhsTarget = 1; \
SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__); \
} else if ((lhs_target) == 2) { \
constexpr int LhsTarget = 2; \
SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__); \
} else { \
LOG(INFO) << "Invalid lhs target: " << (lhs_target); \
} \
} while (0)
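// The constexpr LhsTarget / RhsTarget values selected above (0 = source node,
// 1 = edge, 2 = destination node, matching target_mapping on the Python side)
// are forwarded to Selector<Target>::Call(src, edge, dst) in the kernels below
// to choose which id indexes each operand.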
/*! \brief Generalized SDDMM on Csr format. */
template <int XPU, typename IdType, typename DType>
void SDDMMCsr(const std::string& op,
const BcastOff& bcast,
const CSRMatrix& csr,
NDArray ufeat,
NDArray vfeat,
NDArray out) {
NDArray lhs,
NDArray rhs,
NDArray out,
int lhs_target,
int rhs_target) {
SWITCH_OP(op, Op, {
cpu::SDDMMCsr<IdType, DType, Op>(bcast, csr, ufeat, vfeat, out);
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out);
});
});
}
template void SDDMMCsr<kDLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int64_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int32_t, double>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
/*! \brief Generalized SDDMM on Coo format. */
template <int XPU, typename IdType, typename DType>
void SDDMMCoo(const std::string& op,
const BcastOff& bcast,
const COOMatrix& coo,
NDArray ufeat,
NDArray vfeat,
NDArray out) {
NDArray lhs,
NDArray rhs,
NDArray out,
int lhs_target,
int rhs_target) {
SWITCH_OP(op, Op, {
cpu::SDDMMCoo<IdType, DType, Op>(bcast, coo, ufeat, vfeat, out);
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, coo, lhs, rhs, out);
});
});
}
template void SDDMMCoo<kDLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out);
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int64_t, float>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out);
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int32_t, double>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out);
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out);
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
} // namespace aten
} // namespace dgl
......@@ -8,6 +8,7 @@
#include <dgl/array.h>
#include <dgl/bcast.h>
#include "../selector.h"
namespace dgl {
namespace aten {
......@@ -17,22 +18,23 @@ namespace cpu {
* \brief CPU kernel of g-SDDMM on Csr format.
* \param bcast Broadcast information.
* \param csr The Csr matrix.
* \param ufeat The feature on source nodes.
* \param vfeat The feature on destination nodes.
* \param lhs The left hand side operand feature.
* \param rhs The right hand side operand feature.
* \param out The result feature on edges.
* \note it uses node parallel strategy, different threads are responsible
* for the computation of different nodes.
*/
template <typename IdType, typename DType, typename Op>
template <typename IdType, typename DType, typename Op,
int LhsTarget = 0, int RhsTarget = 2>
void SDDMMCsr(const BcastOff& bcast,
const CSRMatrix& csr,
NDArray ufeat, NDArray vfeat, NDArray out) {
NDArray lhs, NDArray rhs, NDArray out) {
const bool has_idx = !IsNullArray(csr.data);
const IdType* indptr = csr.indptr.Ptr<IdType>();
const IdType* indices = csr.indices.Ptr<IdType>();
const IdType* edges = csr.data.Ptr<IdType>();
const DType* X = ufeat.Ptr<DType>();
const DType* Y = vfeat.Ptr<DType>();
const DType* X = lhs.Ptr<DType>();
const DType* Y = rhs.Ptr<DType>();
const int64_t dim = bcast.out_len,
lhs_dim = bcast.lhs_len,
rhs_dim = bcast.rhs_len,
......@@ -48,10 +50,10 @@ void SDDMMCsr(const BcastOff& bcast,
for (int64_t k = 0; k < dim; ++k) {
const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
const DType* lhs_off = Op::use_lhs? X + rid * lhs_dim +\
lhs_add * reduce_size : nullptr;
const DType* rhs_off = Op::use_rhs? Y + cid * rhs_dim +\
rhs_add * reduce_size : nullptr;
const DType* lhs_off = Op::use_lhs?
X + Selector<LhsTarget>::Call(rid, eid, cid) * lhs_dim + lhs_add * reduce_size : nullptr;
const DType* rhs_off = Op::use_rhs?
Y + Selector<RhsTarget>::Call(rid, eid, cid) * rhs_dim + rhs_add * reduce_size : nullptr;
out_off[k] = Op::Call(lhs_off, rhs_off, reduce_size);
}
}
......@@ -62,22 +64,23 @@ void SDDMMCsr(const BcastOff& bcast,
* \brief CPU kernel of g-SDDMM on Coo format.
* \param bcast Broadcast information.
* \param coo The COO matrix.
* \param ufeat The feature on source nodes.
* \param vfeat The feature on destination nodes.
* \param lhs The left hand side operand feature.
* \param rhs The right hand side operand feature.
* \param out The result feature on edges.
* \note it uses edge parallel strategy, different threads are responsible
* for the computation of different edges.
*/
template <typename IdType, typename DType, typename Op>
template <typename IdType, typename DType, typename Op,
int LhsTarget = 0, int RhsTarget = 2>
void SDDMMCoo(const BcastOff& bcast,
const COOMatrix& coo,
NDArray ufeat, NDArray vfeat, NDArray out) {
NDArray lhs, NDArray rhs, NDArray out) {
const bool has_idx = !IsNullArray(coo.data);
const IdType* row = coo.row.Ptr<IdType>();
const IdType* col = coo.col.Ptr<IdType>();
const IdType* edges = coo.data.Ptr<IdType>();
const DType* X = ufeat.Ptr<DType>();
const DType* Y = vfeat.Ptr<DType>();
const DType* X = lhs.Ptr<DType>();
const DType* Y = rhs.Ptr<DType>();
const int64_t dim = bcast.out_len,
lhs_dim = bcast.lhs_len,
rhs_dim = bcast.rhs_len,
......@@ -93,10 +96,10 @@ void SDDMMCoo(const BcastOff& bcast,
for (int64_t k = 0; k < dim; ++k) {
const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
const DType* lhs_off = Op::use_lhs? X + rid * lhs_dim +\
lhs_add * reduce_size : nullptr;
const DType* rhs_off = Op::use_rhs? Y + cid * rhs_dim +\
rhs_add * reduce_size : nullptr;
const DType* lhs_off = Op::use_lhs ?
X + Selector<LhsTarget>::Call(rid, eid, cid) * lhs_dim + lhs_add * reduce_size : nullptr;
const DType* rhs_off = Op::use_rhs ?
Y + Selector<RhsTarget>::Call(rid, eid, cid) * rhs_dim + rhs_add * reduce_size : nullptr;
out_off[k] = Op::Call(lhs_off, rhs_off, bcast.reduce_size);
}
}
......@@ -186,10 +189,10 @@ struct Dot {
} else if ((op) == "div") { \
typedef dgl::aten::cpu::op::Div<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "copy_u") { \
} else if ((op) == "copy_lhs") { \
typedef dgl::aten::cpu::op::CopyLhs<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "copy_e") { \
} else if ((op) == "copy_rhs") { \
typedef dgl::aten::cpu::op::CopyRhs<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "dot") { \
......
......@@ -347,10 +347,10 @@ template <typename DType> constexpr DType Min<DType>::zero;
} else if ((op) == "div") { \
typedef dgl::aten::cpu::op::Div<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "copy_u") { \
} else if ((op) == "copy_lhs") { \
typedef dgl::aten::cpu::op::CopyLhs<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "copy_e") { \
} else if ((op) == "copy_rhs") { \
typedef dgl::aten::cpu::op::CopyRhs<DType> Op; \
{ __VA_ARGS__ } \
} else { \
......
......@@ -69,7 +69,7 @@ template <typename DType> constexpr bool Div<DType>::use_rhs;
template <typename DType> constexpr bool Div<DType>::reduce_last_dim;
template <typename DType>
struct CopyU {
struct CopyLhs {
static constexpr bool use_lhs = true;
static constexpr bool use_rhs = false;
static constexpr bool reduce_last_dim = false;
......@@ -78,12 +78,12 @@ struct CopyU {
return lhs[0];
}
};
template <typename DType> constexpr bool CopyU<DType>::use_lhs;
template <typename DType> constexpr bool CopyU<DType>::use_rhs;
template <typename DType> constexpr bool CopyU<DType>::reduce_last_dim;
template <typename DType> constexpr bool CopyLhs<DType>::use_lhs;
template <typename DType> constexpr bool CopyLhs<DType>::use_rhs;
template <typename DType> constexpr bool CopyLhs<DType>::reduce_last_dim;
template <typename DType>
struct CopyE {
struct CopyRhs {
static constexpr bool use_lhs = false;
static constexpr bool use_rhs = true;
static constexpr bool reduce_last_dim = false;
......@@ -92,9 +92,9 @@ struct CopyE {
return rhs[0];
}
};
template <typename DType> constexpr bool CopyE<DType>::use_lhs;
template <typename DType> constexpr bool CopyE<DType>::use_rhs;
template <typename DType> constexpr bool CopyE<DType>::reduce_last_dim;
template <typename DType> constexpr bool CopyRhs<DType>::use_lhs;
template <typename DType> constexpr bool CopyRhs<DType>::use_rhs;
template <typename DType> constexpr bool CopyRhs<DType>::reduce_last_dim;
template <typename DType>
struct Dot {
......