Unverified Commit 2fa2b453 authored by Zihao Ye, committed by GitHub

[Feature] Support higher order derivative for message passing. (#1877)

* upd

* fix typo
parent 2b8eb5be
@@ -87,6 +87,7 @@ graph.
 .. autosummary::
     :toctree: ../../generated/

+    gspmm
     u_add_e_sum
     u_sub_e_sum
     u_mul_e_sum
@@ -193,6 +194,7 @@ The following is an example showing how GSDDMM works:
 .. autosummary::
     :toctree: ../../generated/

+    gsddmm
     u_add_v
     u_sub_v
     u_mul_v
......
@@ -101,7 +101,7 @@ a useful manual for in-depth developers.
    api/python/graph
    api/python/heterograph
-   api/python/backend
+   api/python/ops
    api/python/readout
    api/python/batch_heterograph
    api/python/nn
......
@@ -1377,7 +1377,7 @@ def copy_reduce(reducer, graph, target, in_data, out_size, in_map, out_map):
     """
     pass

-def gspmm(g, op, reduce_op, lhs_data, rhs_data):
+def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
     r""" Generalized Sparse Matrix Multiplication interface.
     It fuses two steps into one kernel.
     (1) Computes messages by :attr:`op` source node and edge features.
@@ -1395,7 +1395,7 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data):
     Parameters
     ----------
-    g : DGLHeteroGraph
+    gidx : HeteroGraphIndex
         The input graph.
     op : str
         The binary op's name, could be ``add``, ``sub``, ``mul``, ``div``,
@@ -1414,7 +1414,7 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data):
     """
     pass

-def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
+def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
     r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface.
     It computes edge features by :attr:`op` lhs features and rhs features.
@@ -1428,7 +1428,7 @@ def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
     Parameters
     ----------
-    g : DGLHeteroGraph
+    gidx : HeteroGraphIndex
         The input graph.
     op : str
         Binary operator, could be ``add``, ``sub``, ``mul``, ``div``, ``dot``,
......
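To make the new division of labor concrete: the backend-level `gspmm`/`gsddmm` now take a `HeteroGraphIndex` (`gidx`), while the user-facing wrappers in `dgl.ops` accept a `DGLGraph` and forward its internal index. A minimal usage sketch of the public interface (the graph and feature shapes are illustrative, not part of this diff):

# Illustrative sketch; assumes the `dgl.ops` module added in this change.
import dgl
import torch as th
from dgl.ops import gspmm, gsddmm

g = dgl.graph(([0, 1, 2], [1, 2, 3]))   # 4 nodes, 3 edges
x = th.randn(4, 8)                       # source-node features
w = th.randn(3, 8)                       # edge features

# SpMM: message x_u * w_e on each edge, summed onto destinations -> (4, 8)
z = gspmm(g, 'mul', 'sum', x, w)
# SDDMM: per-edge dot product between source and destination features -> (3, 1)
s = gsddmm(g, 'dot', x, x, lhs_target='u', rhs_target='v')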
@@ -3,7 +3,7 @@ import numpy as np
 from mxnet import nd
 from ...sparse import _gspmm, _gsddmm
 from ...base import dgl_warning
-from .tensor import asnumpy, copy_to, zerocopy_from_numpy, context
+from .tensor import asnumpy, copy_to, zerocopy_from_numpy, context, to_backend_ctx

 def _scatter_nd(index, src, n_rows):
     assert index.shape == src.shape
@@ -95,9 +95,9 @@ def _addsub(op, x):
     return -x if op == 'sub' else x

 class GSpMM(mx.autograd.Function):
-    def __init__(self, g, op, reduce_op):
+    def __init__(self, gidx, op, reduce_op):
         super(GSpMM, self).__init__()
-        self.gidx = g._graph
+        self.gidx = gidx
         self.op = op
         self.reduce_op = reduce_op
@@ -154,18 +154,19 @@ class GSpMM(mx.autograd.Function):
         self.saved_tensors = None
         return dX, dY

-def gspmm(g, op, reduce_op, lhs_data, rhs_data):
-    func = GSpMM(g, op, reduce_op)
+def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
+    func = GSpMM(gidx, op, reduce_op)
+    ctx = to_backend_ctx(gidx.ctx)
     if lhs_data is None:
-        lhs_data = nd.zeros((1,), ctx=g.device)
+        lhs_data = nd.zeros((1,), ctx=ctx)
     if rhs_data is None:
-        rhs_data = nd.zeros((1,), ctx=g.device)
+        rhs_data = nd.zeros((1,), ctx=ctx)
     return func(lhs_data, rhs_data)

 class GSDDMM(mx.autograd.Function):
-    def __init__(self, g, op, lhs_target, rhs_target):
+    def __init__(self, gidx, op, lhs_target, rhs_target):
         super(GSDDMM, self).__init__()
-        self.gidx = g._graph
+        self.gidx = gidx
         self.op = op
         self.lhs_target = lhs_target
         self.rhs_target = rhs_target
@@ -225,10 +226,11 @@ class GSDDMM(mx.autograd.Function):
         self.saved_tensors = None
         return dX, dY

-def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
-    func = GSDDMM(g, op, lhs_target, rhs_target)
+def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
+    func = GSDDMM(gidx, op, lhs_target, rhs_target)
+    ctx = to_backend_ctx(gidx.ctx)
     if lhs_data is None:
-        lhs_data = nd.zeros((1,), ctx=g.device)
+        lhs_data = nd.zeros((1,), ctx=ctx)
     if rhs_data is None:
-        rhs_data = nd.zeros((1,), ctx=g.device)
+        rhs_data = nd.zeros((1,), ctx=ctx)
     return func(lhs_data, rhs_data)
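Because the MXNet entry points no longer receive a `DGLGraph`, the device must be recovered from the graph index's DGL context and mapped to an MXNet context, which is what the newly imported `to_backend_ctx` does. A plausible sketch of such a helper (the actual implementation lives in `.tensor`; the device-type codes below are assumptions mirroring the DLPack convention):

import mxnet as mx

def to_backend_ctx(dglctx):
    # DGL contexts carry (device_type, device_id); we assume 1 = CPU
    # and 2 = GPU, as in DLPack.
    dev_type = dglctx.device_type
    if dev_type == 1:
        return mx.cpu()
    elif dev_type == 2:
        return mx.gpu(dglctx.device_id)
    raise RuntimeError('Unsupported DGL context: %s' % str(dglctx))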
@@ -50,8 +50,7 @@ def _addsub(op, x):
 class GSpMM(th.autograd.Function):
     @staticmethod
-    def forward(ctx, g, op, reduce_op, X, Y):
-        gidx = g._graph
+    def forward(ctx, gidx, op, reduce_op, X, Y):
         out, (argX, argY) = _gspmm(gidx, op, reduce_op, X, Y)
         ctx.backward_cache = gidx, op, reduce_op
         ctx.save_for_backward(X, Y, argX, argY)
@@ -65,11 +64,11 @@ class GSpMM(th.autograd.Function):
             g_rev = gidx.reverse()
             if reduce_op == 'sum':
                 if op in ['mul', 'div']:
-                    dX = _gspmm(g_rev, 'mul', 'sum', dZ, _muldiv(op, Y))[0]
+                    dX = gspmm(g_rev, 'mul', 'sum', dZ, _muldiv(op, Y))
                 elif op in ['add', 'sub']:
-                    dX = _gspmm(g_rev, 'copy_lhs', 'sum', dZ, Y)[0]
+                    dX = gspmm(g_rev, 'copy_lhs', 'sum', dZ, Y)
                 elif op == 'copy_lhs':
-                    dX = _gspmm(g_rev, 'copy_lhs', 'sum', dZ, None)[0]
+                    dX = gspmm(g_rev, 'copy_lhs', 'sum', dZ, None)
             else:
                 dX = th.zeros((X.shape[0],) + dZ.shape[1:], dtype=X.dtype, device=X.device)
                 if op in ['mul', 'div']:
@@ -83,12 +82,12 @@ class GSpMM(th.autograd.Function):
         if op != 'copy_lhs' and ctx.needs_input_grad[4]:
             if reduce_op == 'sum':
                 if op == 'mul' and _need_reduce_last_dim(X, Y):
-                    dY = _gsddmm(gidx, 'dot', X, dZ)
+                    dY = gsddmm(gidx, 'dot', X, dZ)
                 elif op in ['mul', 'div']:
-                    dY = _gsddmm(gidx, 'mul', X, dZ)
+                    dY = gsddmm(gidx, 'mul', X, dZ)
                     if op == 'div': dY = -dY / (Y ** 2)
                 elif op in ['add', 'sub', 'copy_rhs']:
-                    dY = _gsddmm(gidx, 'copy_rhs', X, _addsub(op, dZ))
+                    dY = gsddmm(gidx, 'copy_rhs', X, _addsub(op, dZ))
             else:
                 dY = th.zeros((Y.shape[0],) + dZ.shape[1:], dtype=Y.dtype, device=Y.device)
                 if op in ['mul', 'div']:
@@ -104,8 +103,7 @@ class GSpMM(th.autograd.Function):
 class GSDDMM(th.autograd.Function):
     @staticmethod
-    def forward(ctx, g, op, X, Y, lhs_target, rhs_target):
-        gidx = g._graph
+    def forward(ctx, gidx, op, X, Y, lhs_target, rhs_target):
         out = _gsddmm(gidx, op, X, Y, lhs_target, rhs_target)
         ctx.backward_cache = gidx, op, lhs_target, rhs_target
         ctx.save_for_backward(X, Y)
@@ -119,19 +117,19 @@ class GSDDMM(th.autograd.Function):
             if lhs_target in ['u', 'v']:
                 _gidx = gidx if lhs_target == 'v' else gidx.reverse()
                 if op in ['add', 'sub', 'copy_lhs']:
-                    dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0]
+                    dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)
                 else:  # mul, div, dot
                     if rhs_target == lhs_target:
-                        dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0] * _muldiv(op, Y)
+                        dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ) * _muldiv(op, Y)
                     elif rhs_target == 'e':
-                        dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * _muldiv(op, Y))[0]
+                        dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * _muldiv(op, Y))
                     else:  # rhs_target = !lhs_target
-                        dX = _gspmm(_gidx, 'mul', 'sum', _muldiv(op, Y), dZ)[0]
+                        dX = gspmm(_gidx, 'mul', 'sum', _muldiv(op, Y), dZ)
             else:  # lhs_target == 'e'
                 if op in ['add', 'sub', 'copy_lhs']:
                     dX = dZ
                 else:  # mul, div, dot
-                    dX = _gsddmm(gidx, 'mul', dZ, _muldiv(op, Y), 'e', rhs_target)
+                    dX = gsddmm(gidx, 'mul', dZ, _muldiv(op, Y), 'e', rhs_target)
             dX = _reduce_grad(dX, X.shape)
         else:
             dX = None
@@ -139,29 +137,31 @@ class GSDDMM(th.autograd.Function):
             if rhs_target in ['u', 'v']:
                 _gidx = gidx if rhs_target == 'v' else gidx.reverse()
                 if op in ['add', 'sub', 'copy_rhs']:
-                    dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, _addsub(op, dZ))[0]
+                    dY = gspmm(_gidx, 'copy_rhs', 'sum', None, _addsub(op, dZ))
                 else:  # mul, div, dot
                     if lhs_target == rhs_target:
-                        dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0] * X
+                        dY = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ) * X
                     elif lhs_target == 'e':
-                        dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * X)[0]
+                        dY = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * X)
                     else:  # rhs_target = !lhs_target
-                        dY = _gspmm(_gidx, 'mul', 'sum', X, dZ)[0]
+                        dY = gspmm(_gidx, 'mul', 'sum', X, dZ)
-                    if op == 'div': dY = -dY / (Y ** 2)
+                    if op == 'div':
+                        dY = -dY / (Y ** 2)
             else:
                 if op in ['add', 'sub', 'copy_rhs']:
                     dY = _addsub(op, dZ)
                 else:  # mul, div, dot
-                    dY = _gsddmm(gidx, 'mul', dZ, X, 'e', lhs_target)
+                    dY = gsddmm(gidx, 'mul', dZ, X, 'e', lhs_target)
-                    if op == 'div': dY = -dY / (Y ** 2)
+                    if op == 'div':
+                        dY = -dY / (Y ** 2)
             dY = _reduce_grad(dY, Y.shape)
         else:
             dY = None
         return None, None, dX, dY, None, None

-def gspmm(g, op, reduce_op, lhs_data, rhs_data):
-    return GSpMM.apply(g, op, reduce_op, lhs_data, rhs_data)
+def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
+    return GSpMM.apply(gidx, op, reduce_op, lhs_data, rhs_data)

-def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
-    return GSDDMM.apply(g, op, lhs_data, rhs_data, lhs_target, rhs_target)
+def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
+    return GSDDMM.apply(gidx, op, lhs_data, rhs_data, lhs_target, rhs_target)
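This is the crux of the feature: the PyTorch backward passes now route through the differentiable `gspmm`/`gsddmm` wrappers instead of the raw `_gspmm`/`_gsddmm` kernels (note the dropped `[0]` indexing), so the gradient computation is itself recorded by autograd and can be differentiated again. A minimal sketch of a second-order derivative through message passing (the graph and shapes are illustrative):

import dgl
import torch as th
from dgl.ops import gspmm

g = dgl.graph(([0, 1, 2], [1, 2, 3]))
x = th.randn(4, 5, requires_grad=True)   # node features
w = th.randn(3, 5, requires_grad=True)   # edge features

y = gspmm(g, 'mul', 'sum', x, w)
# First-order gradient; create_graph=True keeps it differentiable.
(gx,) = th.autograd.grad(y.sum(), x, create_graph=True)
# Second order: differentiate the first-order gradient w.r.t. w.
(gw,) = th.autograd.grad(gx.sum(), w)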
@@ -85,8 +85,7 @@ def _muldiv(op, x):
 def _addsub(op, x):
     return -x if op == 'sub' else x

-def gspmm_real(g, op, reduce_op, X, Y):
-    gidx = g._graph
+def gspmm_real(gidx, op, reduce_op, X, Y):
     out, (argX, argY) = _gspmm(gidx, op, reduce_op, X, Y)
     def grad(dZ):
@@ -136,18 +135,17 @@ def gspmm_real(g, op, reduce_op, X, Y):
         return dX, dY
     return out, grad

-def gspmm(g, op, reduce_op, X, Y):
+def gspmm(gidx, op, reduce_op, X, Y):
     @tf.custom_gradient
     def _lambda(X, Y):
-        return gspmm_real(g, op, reduce_op, X, Y)
+        return gspmm_real(gidx, op, reduce_op, X, Y)
     if X is None:
         X = tf.zeros(())
     if Y is None:
         Y = tf.zeros(())
     return _lambda(X, Y)

-def gsddmm_real(g, op, X, Y, lhs_target, rhs_target):
-    gidx = g._graph
+def gsddmm_real(gidx, op, X, Y, lhs_target, rhs_target):
     out = _gsddmm(gidx, op, X, Y, lhs_target, rhs_target)
     def grad(dZ):
@@ -196,10 +194,10 @@ def gsddmm_real(g, op, X, Y, lhs_target, rhs_target):
         return dX, dY
     return out, grad

-def gsddmm(g, op, X, Y, lhs_target='u', rhs_target='v'):
+def gsddmm(gidx, op, X, Y, lhs_target='u', rhs_target='v'):
     @tf.custom_gradient
     def _lambda(X, Y):
-        return gsddmm_real(g, op, X, Y, lhs_target, rhs_target)
+        return gsddmm_real(gidx, op, X, Y, lhs_target, rhs_target)
     if X is None:
         X = tf.zeros(())
     if Y is None:
......
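The TensorFlow backend gets the same renaming, with `tf.custom_gradient` pairing each forward result with a closure that produces its gradients. For readers unfamiliar with the pattern, a self-contained toy version (the op itself is illustrative, not DGL code):

import tensorflow as tf

def scale_real(x, factor):
    out = x * factor
    def grad(dz):
        # Gradient w.r.t. x; factor is captured as a constant.
        return dz * factor
    return out, grad

def scale(x, factor):
    @tf.custom_gradient
    def _lambda(x):
        return scale_real(x, factor)
    return _lambda(x)

y = scale(tf.constant([1.0, 2.0]), 3.0)   # [3.0, 6.0]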
@@ -2,10 +2,45 @@
 from itertools import product
 import sys

-from ..backend import gsddmm
+from ..backend import gsddmm as gsddmm_internal

 __all__ = ['gsddmm', 'copy_u', 'copy_v']

+def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
+    r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface.
+    It computes edge features by applying :attr:`op` to the lhs and rhs features.
+
+    .. math::
+        x_{e} = \phi(x_{lhs}, x_{rhs}), \forall (u, e, v) \in \mathcal{G}
+
+    where :math:`x_{e}` is the returned feature on edges, and :math:`x_u`,
+    :math:`x_v` refer to the features on :attr:`u` and :attr:`v` respectively.
+    :math:`\phi` is the binary operator :attr:`op`, and :math:`\mathcal{G}`
+    is the graph we apply gsddmm on: :attr:`g`. :math:`lhs` and :math:`rhs`
+    are each one of :math:`u`, :math:`v`, :math:`e`.
+
+    Parameters
+    ----------
+    g : DGLGraph
+        The input graph.
+    op : str
+        Binary operator, could be ``add``, ``sub``, ``mul``, ``div``, ``dot``,
+        ``copy_lhs``, ``copy_rhs``.
+    lhs_data : tensor or None
+        The left operand, could be None if it's not required by the op.
+    rhs_data : tensor or None
+        The right operand, could be None if it's not required by the op.
+    lhs_target : str
+        Choice of ``u`` (source), ``e`` (edge) or ``v`` (destination) for the left operand.
+    rhs_target : str
+        Choice of ``u`` (source), ``e`` (edge) or ``v`` (destination) for the right operand.
+
+    Returns
+    -------
+    tensor
+        The result tensor.
+    """
+    return gsddmm_internal(
+        g._graph, op, lhs_data, rhs_data, lhs_target, rhs_target)

 def _gen_sddmm_func(lhs_target, rhs_target, binary_op):
     name = "{}_{}_{}".format(lhs_target, binary_op, rhs_target)
......
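As a worked example of the new public `gsddmm` (a hypothetical graph and features; the `dot` op follows the docstring above):

import dgl
import torch as th
from dgl.ops import gsddmm

g = dgl.graph(([0, 0, 1], [1, 2, 2]))    # 3 nodes, 3 edges
q = th.randn(3, 16)                       # per-node queries
k = th.randn(3, 16)                       # per-node keys

# Attention-style edge scores: dot(q_u, k_v) for every edge (u, v) -> (3, 1)
scores = gsddmm(g, 'dot', q, k, lhs_target='u', rhs_target='v')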
"""dgl spmm operator module.""" """dgl spmm operator module."""
import sys import sys
from ..backend import gspmm from ..backend import gspmm as gspmm_internal
__all__ = ['gspmm'] __all__ = ['gspmm']
def gspmm(g, op, reduce_op, lhs_data, rhs_data):
r""" Generalized Sparse Matrix Multiplication interface.
It fuses two steps into one kernel.
(1) Computes messages by :attr:`op` source node and edge features.
(2) Aggregate the messages by :attr:`reduce_op` as the features on destination nodes.
.. math::
x_v = \psi_{(u, v, e)\in \mathcal{G}}(\rho(x_u, x_e))
where :math:`x_v` is the returned feature on destination nodes, and :math`x_u`,
:math:`x_e` refers to :attr:`u`, :attr:`e` respectively. :math:`\rho` means binary
operator :attr:`op` and :math:`\psi` means reduce operator :attr:`reduce_op`,
:math:`\mathcal{G}` is the graph we apply gspmm on: :attr:`g`.
Note that this function does not handle gradients.
Parameters
----------
g : DGLGraph
The input graph.
op : str
The binary op's name, could be ``add``, ``sub``, ``mul``, ``div``,
``copy_lhs``, ``copy_rhs``.
reduce_op : str
Reduce operator, could be ``sum``, ``max``, ``min``.
lhs_data : tensor or None
The left operand, could be None if it's not required by the op.
rhs_data : tensor or None
The right operand, could be None if it's not required by the op.
Returns
-------
tensor
The result tensor.
"""
return gspmm_internal(g._graph, op, reduce_op, lhs_data, rhs_data)
def _gen_spmm_func(binary_op, reduce_op): def _gen_spmm_func(binary_op, reduce_op):
name = "u_{}_e_{}".format(binary_op, reduce_op) name = "u_{}_e_{}".format(binary_op, reduce_op)
docstring = """Generalized SpMM function. docstring = """Generalized SpMM function.
......
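And a corresponding sketch for the public `gspmm` (again with an illustrative graph), computing a degree-normalized neighborhood sum in the style of GCN aggregation:

import dgl
import torch as th
from dgl.ops import gspmm

g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
h = th.randn(4, 16)                          # node features

# Sum incoming source features onto each destination node...
h_sum = gspmm(g, 'copy_lhs', 'sum', h, None)
# ...then normalize by in-degree (clamped to avoid division by zero).
h_agg = h_sum / g.in_degrees().clamp(min=1).unsqueeze(-1).float()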
@@ -109,9 +109,7 @@ def _gspmm(gidx, op, reduce_op, u, e):
     Notes
     -----
-    This function does not handle gradients, and for scalar input features,
-    we expand its dimension with an additional dimension of length one. (e.g.
-    (90,) to (90, 1) for a graph with 90 nodes/edges).
+    This function does not handle gradients.
     """
     if gidx.number_of_etypes() != 1:
         raise DGLError("We only support gspmm on graph with one edge type")
@@ -192,9 +190,7 @@ def _gsddmm(gidx, op, lhs, rhs, lhs_target='u', rhs_target='v'):
     Notes
     -----
-    This function does not handle gradients, and for scalar input features,
-    we expand its dimension with an additional dimension of length one. (e.g.
-    (90,) to (90, 1) for a graph with 90 nodes/edges).
+    This function does not handle gradients.
     """
     if gidx.number_of_etypes() != 1:
         raise DGLError("We only support gsddmm on graph with one edge type")
......
-from dgl.backend import gspmm, gsddmm
+from dgl.ops import gspmm, gsddmm
 from utils import parametrize_dtype
 import dgl
 import random
......
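Finally, with the tests importing the wrappers from `dgl.ops`, second-order behavior can be verified numerically. A sketch in the spirit of these tests (graph, dtype, and op choices are illustrative, PyTorch backend assumed):

import dgl
import torch as th
from dgl.ops import gspmm

g = dgl.graph(([0, 1, 2], [1, 2, 0]))
u = th.randn(3, 4, dtype=th.float64, requires_grad=True)
e = th.randn(3, 4, dtype=th.float64, requires_grad=True)

# gradgradcheck numerically verifies second-order gradients.
th.autograd.gradgradcheck(lambda u, e: gspmm(g, 'mul', 'sum', u, e), (u, e))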