Unverified commit 2fa2b453, authored by Zihao Ye and committed by GitHub

[Feature] Support higher order derivative for message passing. (#1877)

* upd

* fix typo
parent 2b8eb5be
......@@ -87,6 +87,7 @@ graph.
.. autosummary::
:toctree: ../../generated/
gspmm
u_add_e_sum
u_sub_e_sum
u_mul_e_sum
......@@ -193,6 +194,7 @@ The following is an example showing how GSDDMM works:
.. autosummary::
:toctree: ../../generated/
gsddmm
u_add_v
u_sub_v
u_mul_v
......
......@@ -101,7 +101,7 @@ a useful manual for in-depth developers.
api/python/graph
api/python/heterograph
api/python/backend
api/python/ops
api/python/readout
api/python/batch_heterograph
api/python/nn
......
......@@ -1377,7 +1377,7 @@ def copy_reduce(reducer, graph, target, in_data, out_size, in_map, out_map):
"""
pass
def gspmm(g, op, reduce_op, lhs_data, rhs_data):
def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
r""" Generalized Sparse Matrix Multiplication interface.
It fuses two steps into one kernel.
(1) Computes messages by applying :attr:`op` to source node and edge features.
......@@ -1395,7 +1395,7 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data):
Parameters
----------
g : DGLHeteroGraph
gidx : HeteroGraphIndex
The input graph.
op : str
Name of the binary op; can be ``add``, ``sub``, ``mul``, ``div``,
......@@ -1414,7 +1414,7 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data):
"""
pass
def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface.
It computes edge features by applying :attr:`op` to the lhs and rhs features.
......@@ -1428,7 +1428,7 @@ def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
Parameters
----------
g : DGLHeteroGraph
gidx : HeteroGraphIndex
The input graph.
op : str
Binary operator; can be ``add``, ``sub``, ``mul``, ``div``, ``dot``,
......
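The interface change above (``g``/DGLHeteroGraph to ``gidx``/HeteroGraphIndex) means the backend kernels no longer accept a graph object; callers are expected to unwrap the underlying graph index first, as the ops-level wrappers later in this diff do with ``g._graph``. A minimal sketch of that calling convention, not part of the diff (the wrapper function name is hypothetical, and it assumes the backend module re-exports gspmm as shown below):

import torch as th
import dgl
from dgl import backend as F   # backend-dispatched gspmm/gsddmm

def copy_u_sum_sketch(g, node_feat):
    # g is a DGLGraph; the backend kernel now expects g._graph, the
    # underlying HeteroGraphIndex, rather than the graph object itself.
    return F.gspmm(g._graph, 'copy_lhs', 'sum', node_feat, None)

g = dgl.graph(([0, 1, 2], [1, 2, 3]))                 # 4 nodes, 3 edges
out = copy_u_sum_sketch(g, th.randn(4, 8))             # (4, 8): summed source features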
......@@ -3,7 +3,7 @@ import numpy as np
from mxnet import nd
from ...sparse import _gspmm, _gsddmm
from ...base import dgl_warning
from .tensor import asnumpy, copy_to, zerocopy_from_numpy, context
from .tensor import asnumpy, copy_to, zerocopy_from_numpy, context, to_backend_ctx
def _scatter_nd(index, src, n_rows):
assert index.shape == src.shape
......@@ -95,9 +95,9 @@ def _addsub(op, x):
return -x if op == 'sub' else x
class GSpMM(mx.autograd.Function):
def __init__(self, g, op, reduce_op):
def __init__(self, gidx, op, reduce_op):
super(GSpMM, self).__init__()
self.gidx = g._graph
self.gidx = gidx
self.op = op
self.reduce_op = reduce_op
......@@ -154,18 +154,19 @@ class GSpMM(mx.autograd.Function):
self.saved_tensors = None
return dX, dY
def gspmm(g, op, reduce_op, lhs_data, rhs_data):
func = GSpMM(g, op, reduce_op)
def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
func = GSpMM(gidx, op, reduce_op)
ctx = to_backend_ctx(gidx.ctx)
if lhs_data is None:
lhs_data = nd.zeros((1,), ctx=g.device)
lhs_data = nd.zeros((1,), ctx=ctx)
if rhs_data is None:
rhs_data = nd.zeros((1,), ctx=g.device)
rhs_data = nd.zeros((1,), ctx=ctx)
return func(lhs_data, rhs_data)
class GSDDMM(mx.autograd.Function):
def __init__(self, g, op, lhs_target, rhs_target):
def __init__(self, gidx, op, lhs_target, rhs_target):
super(GSDDMM, self).__init__()
self.gidx = g._graph
self.gidx = gidx
self.op = op
self.lhs_target = lhs_target
self.rhs_target = rhs_target
......@@ -225,10 +226,11 @@ class GSDDMM(mx.autograd.Function):
self.saved_tensors = None
return dX, dY
def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
func = GSDDMM(g, op, lhs_target, rhs_target)
def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
func = GSDDMM(gidx, op, lhs_target, rhs_target)
ctx = to_backend_ctx(gidx.ctx)
if lhs_data is None:
lhs_data = nd.zeros((1,), ctx=g.device)
lhs_data = nd.zeros((1,), ctx=ctx)
if rhs_data is None:
rhs_data = nd.zeros((1,), ctx=g.device)
rhs_data = nd.zeros((1,), ctx=ctx)
return func(lhs_data, rhs_data)
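In the MXNet wrapper above, ``to_backend_ctx(gidx.ctx)`` evidently converts the device of the graph index into an MXNet context, since the wrapper can no longer read ``g.device`` from a graph object; the 1-element ``nd.zeros`` placeholders appear to exist because ``mx.autograd.Function`` expects NDArray inputs, so an operand that the op does not use (e.g. the rhs of ``copy_lhs``) cannot be passed as None. A hedged sketch of that pattern (the helper name is hypothetical, not part of the diff):

from mxnet import nd

def _placeholder_if_none(data, ctx):
    # Mirror the gspmm/gsddmm wrappers above: replace a missing operand
    # with a dummy 1-element array on the same device as the graph index.
    return nd.zeros((1,), ctx=ctx) if data is None else data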
......@@ -50,8 +50,7 @@ def _addsub(op, x):
class GSpMM(th.autograd.Function):
@staticmethod
def forward(ctx, g, op, reduce_op, X, Y):
gidx = g._graph
def forward(ctx, gidx, op, reduce_op, X, Y):
out, (argX, argY) = _gspmm(gidx, op, reduce_op, X, Y)
ctx.backward_cache = gidx, op, reduce_op
ctx.save_for_backward(X, Y, argX, argY)
......@@ -65,11 +64,11 @@ class GSpMM(th.autograd.Function):
g_rev = gidx.reverse()
if reduce_op == 'sum':
if op in ['mul', 'div']:
dX = _gspmm(g_rev, 'mul', 'sum', dZ, _muldiv(op, Y))[0]
dX = gspmm(g_rev, 'mul', 'sum', dZ, _muldiv(op, Y))
elif op in ['add', 'sub']:
dX = _gspmm(g_rev, 'copy_lhs', 'sum', dZ, Y)[0]
dX = gspmm(g_rev, 'copy_lhs', 'sum', dZ, Y)
elif op == 'copy_lhs':
dX = _gspmm(g_rev, 'copy_lhs', 'sum', dZ, None)[0]
dX = gspmm(g_rev, 'copy_lhs', 'sum', dZ, None)
else:
dX = th.zeros((X.shape[0],) + dZ.shape[1:], dtype=X.dtype, device=X.device)
if op in ['mul', 'div']:
......@@ -83,12 +82,12 @@ class GSpMM(th.autograd.Function):
if op != 'copy_lhs' and ctx.needs_input_grad[4]:
if reduce_op == 'sum':
if op == 'mul' and _need_reduce_last_dim(X, Y):
dY = _gsddmm(gidx, 'dot', X, dZ)
dY = gsddmm(gidx, 'dot', X, dZ)
elif op in ['mul', 'div']:
dY = _gsddmm(gidx, 'mul', X, dZ)
dY = gsddmm(gidx, 'mul', X, dZ)
if op == 'div': dY = -dY / (Y ** 2)
elif op in ['add', 'sub', 'copy_rhs']:
dY = _gsddmm(gidx, 'copy_rhs', X, _addsub(op, dZ))
dY = gsddmm(gidx, 'copy_rhs', X, _addsub(op, dZ))
else:
dY = th.zeros((Y.shape[0],) + dZ.shape[1:], dtype=Y.dtype, device=Y.device)
if op in ['mul', 'div']:
......@@ -104,8 +103,7 @@ class GSpMM(th.autograd.Function):
class GSDDMM(th.autograd.Function):
@staticmethod
def forward(ctx, g, op, X, Y, lhs_target, rhs_target):
gidx = g._graph
def forward(ctx, gidx, op, X, Y, lhs_target, rhs_target):
out = _gsddmm(gidx, op, X, Y, lhs_target, rhs_target)
ctx.backward_cache = gidx, op, lhs_target, rhs_target
ctx.save_for_backward(X, Y)
......@@ -119,19 +117,19 @@ class GSDDMM(th.autograd.Function):
if lhs_target in ['u', 'v']:
_gidx = gidx if lhs_target == 'v' else gidx.reverse()
if op in ['add', 'sub', 'copy_lhs']:
dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0]
dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)
else: # mul, div, dot
if rhs_target == lhs_target:
dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0] * _muldiv(op, Y)
dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ) * _muldiv(op, Y)
elif rhs_target == 'e':
dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * _muldiv(op, Y))[0]
dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * _muldiv(op, Y))
else: # rhs_target = !lhs_target
dX = _gspmm(_gidx, 'mul', 'sum', _muldiv(op, Y), dZ)[0]
dX = gspmm(_gidx, 'mul', 'sum', _muldiv(op, Y), dZ)
else: # lhs_target == 'e'
if op in ['add', 'sub', 'copy_lhs']:
dX = dZ
else: # mul, div, dot
dX = _gsddmm(gidx, 'mul', dZ, _muldiv(op, Y), 'e', rhs_target)
dX = gsddmm(gidx, 'mul', dZ, _muldiv(op, Y), 'e', rhs_target)
dX = _reduce_grad(dX, X.shape)
else:
dX = None
......@@ -139,29 +137,31 @@ class GSDDMM(th.autograd.Function):
if rhs_target in ['u', 'v']:
_gidx = gidx if rhs_target == 'v' else gidx.reverse()
if op in ['add', 'sub', 'copy_rhs']:
dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, _addsub(op, dZ))[0]
dY = gspmm(_gidx, 'copy_rhs', 'sum', None, _addsub(op, dZ))
else: # mul, div, dot
if lhs_target == rhs_target:
dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0] * X
dY = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ) * X
elif lhs_target == 'e':
dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * X)[0]
dY = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * X)
else: # rhs_target = !lhs_target
dY = _gspmm(_gidx, 'mul', 'sum', X, dZ)[0]
if op == 'div': dY = -dY / (Y ** 2)
dY = gspmm(_gidx, 'mul', 'sum', X, dZ)
if op == 'div':
dY = -dY / (Y ** 2)
else:
if op in ['add', 'sub', 'copy_rhs']:
dY = _addsub(op, dZ)
else: # mul, div, dot
dY = _gsddmm(gidx, 'mul', dZ, X, 'e', lhs_target)
if op == 'div': dY = -dY / (Y ** 2)
dY = gsddmm(gidx, 'mul', dZ, X, 'e', lhs_target)
if op == 'div':
dY = -dY / (Y ** 2)
dY = _reduce_grad(dY, Y.shape)
else:
dY = None
return None, None, dX, dY, None, None
def gspmm(g, op, reduce_op, lhs_data, rhs_data):
return GSpMM.apply(g, op, reduce_op, lhs_data, rhs_data)
def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
return GSpMM.apply(gidx, op, reduce_op, lhs_data, rhs_data)
def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
return GSDDMM.apply(g, op, lhs_data, rhs_data, lhs_target, rhs_target)
def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
return GSDDMM.apply(gidx, op, lhs_data, rhs_data, lhs_target, rhs_target)
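The substantive change in the PyTorch backend is that ``backward`` now calls the differentiable ``gspmm``/``gsddmm`` wrappers instead of the raw ``_gspmm``/``_gsddmm`` kernels, so the backward pass itself is recorded by autograd and second-order derivatives become available. A hedged sketch of what this enables (graph construction and shapes are illustrative, not taken from the diff):

import torch as th
import dgl
from dgl.ops import gspmm      # user-facing wrapper added later in this diff

g = dgl.graph(([0, 1, 2], [1, 2, 3]))        # 4 nodes, 3 edges
x = th.randn(4, 5, requires_grad=True)        # source node features
e = th.randn(3, 5, requires_grad=True)        # edge features

y = gspmm(g, 'mul', 'sum', x, e)              # u_mul_e followed by sum reduction
loss = (y * y).sum()

# First-order gradient, keeping the graph so it can be differentiated again.
gx, = th.autograd.grad(loss, x, create_graph=True)
# Second-order gradient: possible because backward was built from the
# differentiable gspmm/gsddmm wrappers, not the raw kernels.
gxx, = th.autograd.grad(gx.sum(), x)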
......@@ -85,8 +85,7 @@ def _muldiv(op, x):
def _addsub(op, x):
return -x if op == 'sub' else x
def gspmm_real(g, op, reduce_op, X, Y):
gidx = g._graph
def gspmm_real(gidx, op, reduce_op, X, Y):
out, (argX, argY) = _gspmm(gidx, op, reduce_op, X, Y)
def grad(dZ):
......@@ -136,18 +135,17 @@ def gspmm_real(g, op, reduce_op, X, Y):
return dX, dY
return out, grad
def gspmm(g, op, reduce_op, X, Y):
def gspmm(gidx, op, reduce_op, X, Y):
@tf.custom_gradient
def _lambda(X, Y):
return gspmm_real(g, op, reduce_op, X, Y)
return gspmm_real(gidx, op, reduce_op, X, Y)
if X is None:
X = tf.zeros(())
if Y is None:
Y = tf.zeros(())
return _lambda(X, Y)
def gsddmm_real(g, op, X, Y, lhs_target, rhs_target):
gidx = g._graph
def gsddmm_real(gidx, op, X, Y, lhs_target, rhs_target):
out = _gsddmm(gidx, op, X, Y, lhs_target, rhs_target)
def grad(dZ):
......@@ -196,10 +194,10 @@ def gsddmm_real(g, op, X, Y, lhs_target, rhs_target):
return dX, dY
return out, grad
def gsddmm(g, op, X, Y, lhs_target='u', rhs_target='v'):
def gsddmm(gidx, op, X, Y, lhs_target='u', rhs_target='v'):
@tf.custom_gradient
def _lambda(X, Y):
return gsddmm_real(g, op, X, Y, lhs_target, rhs_target)
return gsddmm_real(gidx, op, X, Y, lhs_target, rhs_target)
if X is None:
X = tf.zeros(())
if Y is None:
......
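The TensorFlow backend follows the same idea: if the ``grad`` closures returned from ``gspmm_real``/``gsddmm_real`` are built from the differentiable wrappers (mirroring the PyTorch change; the relevant hunk bodies are not fully shown above), a nested ``tf.GradientTape`` can differentiate the gradient again. A hedged sketch, assuming DGL is running with the TensorFlow backend (``DGLBACKEND=tensorflow``):

import tensorflow as tf
import dgl
from dgl.ops import gspmm

g = dgl.graph(([0, 1, 2], [1, 2, 3]))
x = tf.random.normal((4, 5))
e = tf.random.normal((3, 5))

with tf.GradientTape() as outer:
    outer.watch(x)
    with tf.GradientTape() as inner:
        inner.watch(x)
        y = gspmm(g, 'mul', 'sum', x, e)
        loss = tf.reduce_sum(y * y)
    gx = inner.gradient(loss, x)    # first-order gradient
gxx = outer.gradient(gx, x)         # second-order gradient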
......@@ -2,10 +2,45 @@
from itertools import product
import sys
from ..backend import gsddmm
from ..backend import gsddmm as gsddmm_internal
__all__ = ['gsddmm', 'copy_u', 'copy_v']
def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface.
It computes edge features by applying :attr:`op` to the lhs and rhs features.
.. math::
x_{e} = \phi(x_{lhs}, x_{rhs}), \forall (u, e, v) \in \mathcal{G}
where :math:`x_{e}` is the returned feature on edges and :math:`x_u`,
:math:`x_v` refer to :attr:`u`, :attr:`v` respectively. :math:`\phi`
is the binary operator :attr:`op`, and :math:`\mathcal{G}` is the graph
on which gsddmm is applied (:attr:`g`). :math:`lhs` and :math:`rhs` can each be one of :math:`u`, :math:`v`, :math:`e`.
Parameters
----------
g : DGLGraph
The input graph.
op : str
Binary operator; can be ``add``, ``sub``, ``mul``, ``div``, ``dot``,
``copy_lhs``, ``copy_rhs``.
lhs_data : tensor or None
The left operand; can be None if it is not required by the op.
rhs_data : tensor or None
The right operand; can be None if it is not required by the op.
lhs_target : str
Choice of ``u`` (source), ``e`` (edge), or ``v`` (destination) for the left operand.
rhs_target : str
Choice of ``u`` (source), ``e`` (edge), or ``v`` (destination) for the right operand.
Returns
-------
tensor
The result tensor.
"""
return gsddmm_internal(
g._graph, op, lhs_data, rhs_data, lhs_target, rhs_target)
def _gen_sddmm_func(lhs_target, rhs_target, binary_op):
name = "{}_{}_{}".format(lhs_target, binary_op, rhs_target)
......
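A hedged usage sketch for the dgl.ops.gsddmm wrapper defined above, computing one score per edge as the dot product of source and destination node features; the generated convenience functions (for example ``u_dot_v``) wrap the same call. The graph and shapes are illustrative, not taken from the diff:

import torch as th
import dgl
from dgl.ops import gsddmm

g = dgl.graph(([0, 1, 2], [1, 2, 3]))
h_src = th.randn(4, 8)
h_dst = th.randn(4, 8)

# 'dot' reduces the last feature dimension, giving one value per edge.
score = gsddmm(g, 'dot', h_src, h_dst, lhs_target='u', rhs_target='v')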
"""dgl spmm operator module."""
import sys
from ..backend import gspmm
from ..backend import gspmm as gspmm_internal
__all__ = ['gspmm']
def gspmm(g, op, reduce_op, lhs_data, rhs_data):
r""" Generalized Sparse Matrix Multiplication interface.
It fuses two steps into one kernel.
(1) Computes messages by applying :attr:`op` to source node and edge features.
(2) Aggregates the messages by :attr:`reduce_op` as the features on destination nodes.
.. math::
x_v = \psi_{(u, v, e)\in \mathcal{G}}(\rho(x_u, x_e))
where :math:`x_v` is the returned feature on destination nodes, and :math:`x_u`,
:math:`x_e` refer to :attr:`u`, :attr:`e` respectively. :math:`\rho` is the binary
operator :attr:`op`, :math:`\psi` is the reduce operator :attr:`reduce_op`, and
:math:`\mathcal{G}` is the graph on which gspmm is applied (:attr:`g`).
Note that this function does not handle gradients.
Parameters
----------
g : DGLGraph
The input graph.
op : str
Name of the binary op; can be ``add``, ``sub``, ``mul``, ``div``,
``copy_lhs``, ``copy_rhs``.
reduce_op : str
Reduce operator; can be ``sum``, ``max``, ``min``.
lhs_data : tensor or None
The left operand; can be None if it is not required by the op.
rhs_data : tensor or None
The right operand; can be None if it is not required by the op.
Returns
-------
tensor
The result tensor.
"""
return gspmm_internal(g._graph, op, reduce_op, lhs_data, rhs_data)
def _gen_spmm_func(binary_op, reduce_op):
name = "u_{}_e_{}".format(binary_op, reduce_op)
docstring = """Generalized SpMM function.
......
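Similarly, a hedged usage sketch for dgl.ops.gspmm: with ``copy_lhs``/``sum`` it aggregates source node features onto destination nodes, the basic message-passing step that the generated ``u_*_e_*`` convenience functions build on. The graph and shapes are illustrative:

import torch as th
import dgl
from dgl.ops import gspmm

g = dgl.graph(([0, 1, 2], [1, 2, 3]))
h = th.randn(4, 16)

# Sum the features of incoming source nodes for every destination node.
h_new = gspmm(g, 'copy_lhs', 'sum', h, None)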
......@@ -109,9 +109,7 @@ def _gspmm(gidx, op, reduce_op, u, e):
Notes
-----
This function does not handle gradients, and for scalar input features,
we expand its dimension with an additional dimension of length one. (e.g.
(90,) to (90, 1) for a graph with 90 nodes/edges).
This function does not handle gradients.
"""
if gidx.number_of_etypes() != 1:
raise DGLError("We only support gspmm on graph with one edge type")
......@@ -192,9 +190,7 @@ def _gsddmm(gidx, op, lhs, rhs, lhs_target='u', rhs_target='v'):
Notes
-----
This function does not handle gradients, and for scalar input features,
we expand its dimension with an additional dimension of length one. (e.g.
(90,) to (90, 1) for a graph with 90 nodes/edges).
This function does not handle gradients.
"""
if gidx.number_of_etypes() != 1:
raise DGLError("We only support gsddmm on graph with one edge type")
......
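For contrast with the differentiable wrappers, the low-level kernels documented above return raw outputs (plus arg tensors used by ``min``/``max`` reductions) and, as the Notes say, do not handle gradients. A hedged sketch of calling one directly, with an illustrative graph and shapes:

import torch as th
import dgl
from dgl.sparse import _gspmm   # low-level kernel, no autograd

g = dgl.graph(([0, 1, 2], [1, 2, 3]))
u = th.randn(4, 4)
e = th.randn(3, 4)

# Returns the reduced output and the argmax indices for the 'max' reduction.
out, (arg_u, arg_e) = _gspmm(g._graph, 'mul', 'max', u, e)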
from dgl.backend import gspmm, gsddmm
from dgl.ops import gspmm, gsddmm
from utils import parametrize_dtype
import dgl
import random
......