OpenDAS / dgl: commit 2fa2b453 (unverified)

Authored Jul 28, 2020 by Zihao Ye; committed by GitHub on Jul 28, 2020.

[Feature] Support higher order derivative for message passing. (#1877)

* upd
* fix typo

Parent: 2b8eb5be
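In short: the backward passes of the fused gspmm/gsddmm kernels are rewritten in terms of the differentiable ops themselves, so autograd can differentiate the gradients again. A minimal sketch of what this enables, assuming the dgl.ops API introduced by this commit (the toy graph built with dgl.rand_graph and all shapes are illustrative only, not from the commit):

    import torch
    import dgl
    from dgl.ops import gspmm

    g = dgl.rand_graph(100, 500)                  # random toy graph
    x = torch.randn(100, 16, requires_grad=True)

    # Sum-aggregate source features onto destinations ('copy_lhs' ignores rhs).
    y = gspmm(g, 'copy_lhs', 'sum', x, None)
    loss = (y ** 2).sum()

    # create_graph=True records the backward pass itself, so a second
    # differentiation (grad of grad) is possible, which is the point of this commit.
    (gx,) = torch.autograd.grad(loss, x, create_graph=True)
    (ggx,) = torch.autograd.grad(gx.sum(), x)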
Changes: 10 changed files, with 131 additions and 61 deletions (+131, -61).

    docs/source/api/python/ops.rst             +2   -0
    docs/source/index.rst                      +1   -1
    python/dgl/backend/backend.py              +4   -4
    python/dgl/backend/mxnet/sparse.py         +15  -13
    python/dgl/backend/pytorch/sparse.py       +26  -26
    python/dgl/backend/tensorflow/sparse.py    +6   -8
    python/dgl/ops/sddmm.py                    +36  -1
    python/dgl/ops/spmm.py                     +38  -1
    python/dgl/sparse.py                       +2   -6
    tests/compute/test_sparse.py               +1   -1
docs/source/api/python/ops.rst

@@ -87,6 +87,7 @@ graph.
 .. autosummary::
     :toctree: ../../generated/

+    gspmm
     u_add_e_sum
     u_sub_e_sum
     u_mul_e_sum

@@ -193,6 +194,7 @@ The following is an example showing how GSDDMM works:
 .. autosummary::
     :toctree: ../../generated/

+    gsddmm
     u_add_v
     u_sub_v
     u_mul_v
docs/source/index.rst

@@ -101,7 +101,7 @@ a useful manual for in-depth developers.
    api/python/graph
    api/python/heterograph
-   api/python/backend
+   api/python/ops
    api/python/readout
    api/python/batch_heterograph
    api/python/nn
python/dgl/backend/backend.py

@@ -1377,7 +1377,7 @@ def copy_reduce(reducer, graph, target, in_data, out_size, in_map, out_map):
     """
     pass

-def gspmm(g, op, reduce_op, lhs_data, rhs_data):
+def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
     r""" Generalized Sparse Matrix Multiplication interface.
     It fuses two steps into one kernel.
     (1) Computes messages by :attr:`op` source node and edge features.

@@ -1395,7 +1395,7 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data):
     Parameters
     ----------
-    g : DGLHeteroGraph
+    gidx : HeteroGraphIndex
         The input graph.
     op : str
         The binary op's name, could be ``add``, ``sub``, ``mul``, ``div``,

@@ -1414,7 +1414,7 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data):
     """
     pass

-def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
+def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
     r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface.
     It computes edge features by :attr:`op` lhs features and rhs features.

@@ -1428,7 +1428,7 @@ def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
     Parameters
     ----------
-    g : DGLHeteroGraph
+    gidx : HeteroGraphIndex
         The input graph.
     op : str
         Binary operator, could be ``add``, ``sub``, ``mul``, ``div``, ``dot``,
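As a reading aid (not DGL code): in dense terms, gspmm(gidx, 'mul', 'sum', X, Y) computes, for every destination node v, the sum over incoming edges (u, e, v) of X[u] * Y[e], matching the docstring's formula. A NumPy sketch of that reference semantics, with hypothetical edge arrays src/dst:

    import numpy as np

    def gspmm_dense_reference(src, dst, X, Y, num_dst):
        # src[e], dst[e]: endpoints of edge e; X: source-node features,
        # Y: edge features. Purely illustrative, not part of the library.
        Z = np.zeros((num_dst,) + X.shape[1:])
        for e in range(len(src)):
            Z[dst[e]] += X[src[e]] * Y[e]   # message ('mul'), then reduce ('sum')
        return Z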
python/dgl/backend/mxnet/sparse.py

@@ -3,7 +3,7 @@ import numpy as np
 from mxnet import nd
 from ...sparse import _gspmm, _gsddmm
 from ...base import dgl_warning
-from .tensor import asnumpy, copy_to, zerocopy_from_numpy, context
+from .tensor import asnumpy, copy_to, zerocopy_from_numpy, context, to_backend_ctx

 def _scatter_nd(index, src, n_rows):
     assert index.shape == src.shape

@@ -95,9 +95,9 @@ def _addsub(op, x):
     return -x if op == 'sub' else x

 class GSpMM(mx.autograd.Function):
-    def __init__(self, g, op, reduce_op):
+    def __init__(self, gidx, op, reduce_op):
         super(GSpMM, self).__init__()
-        self.gidx = g._graph
+        self.gidx = gidx
         self.op = op
         self.reduce_op = reduce_op

@@ -154,18 +154,19 @@ class GSpMM(mx.autograd.Function):
         self.saved_tensors = None
         return dX, dY

-def gspmm(g, op, reduce_op, lhs_data, rhs_data):
-    func = GSpMM(g, op, reduce_op)
+def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
+    func = GSpMM(gidx, op, reduce_op)
+    ctx = to_backend_ctx(gidx.ctx)
     if lhs_data is None:
-        lhs_data = nd.zeros((1,), ctx=g.device)
+        lhs_data = nd.zeros((1,), ctx=ctx)
     if rhs_data is None:
-        rhs_data = nd.zeros((1,), ctx=g.device)
+        rhs_data = nd.zeros((1,), ctx=ctx)
     return func(lhs_data, rhs_data)

 class GSDDMM(mx.autograd.Function):
-    def __init__(self, g, op, lhs_target, rhs_target):
+    def __init__(self, gidx, op, lhs_target, rhs_target):
         super(GSDDMM, self).__init__()
-        self.gidx = g._graph
+        self.gidx = gidx
         self.op = op
         self.lhs_target = lhs_target
         self.rhs_target = rhs_target

@@ -225,10 +226,11 @@ class GSDDMM(mx.autograd.Function):
         self.saved_tensors = None
         return dX, dY

-def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
-    func = GSDDMM(g, op, lhs_target, rhs_target)
+def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
+    func = GSDDMM(gidx, op, lhs_target, rhs_target)
+    ctx = to_backend_ctx(gidx.ctx)
     if lhs_data is None:
-        lhs_data = nd.zeros((1,), ctx=g.device)
+        lhs_data = nd.zeros((1,), ctx=ctx)
     if rhs_data is None:
-        rhs_data = nd.zeros((1,), ctx=g.device)
+        rhs_data = nd.zeros((1,), ctx=ctx)
     return func(lhs_data, rhs_data)
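The MXNet wrapper now receives the bare graph index, which exposes a DGL context (gidx.ctx) rather than a framework device, hence the new to_backend_ctx conversion before allocating the dummy operands. A rough standalone sketch of the idea (the helper below is an illustrative assumption, not DGL's implementation):

    import mxnet as mx
    from mxnet import nd

    def to_mx_ctx(device_type, device_id):
        # Illustrative stand-in for to_backend_ctx: map a device description
        # onto an MXNet context object.
        return mx.gpu(device_id) if device_type == 'gpu' else mx.cpu(device_id)

    ctx = to_mx_ctx('cpu', 0)
    placeholder = nd.zeros((1,), ctx=ctx)   # dummy operand when lhs/rhs is None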
python/dgl/backend/pytorch/sparse.py

@@ -50,8 +50,7 @@ def _addsub(op, x):

 class GSpMM(th.autograd.Function):
     @staticmethod
-    def forward(ctx, g, op, reduce_op, X, Y):
-        gidx = g._graph
+    def forward(ctx, gidx, op, reduce_op, X, Y):
         out, (argX, argY) = _gspmm(gidx, op, reduce_op, X, Y)
         ctx.backward_cache = gidx, op, reduce_op
         ctx.save_for_backward(X, Y, argX, argY)

@@ -65,11 +64,11 @@ class GSpMM(th.autograd.Function):
         g_rev = gidx.reverse()
         if reduce_op == 'sum':
             if op in ['mul', 'div']:
-                dX = _gspmm(g_rev, 'mul', 'sum', dZ, _muldiv(op, Y))[0]
+                dX = gspmm(g_rev, 'mul', 'sum', dZ, _muldiv(op, Y))
             elif op in ['add', 'sub']:
-                dX = _gspmm(g_rev, 'copy_lhs', 'sum', dZ, Y)[0]
+                dX = gspmm(g_rev, 'copy_lhs', 'sum', dZ, Y)
             elif op == 'copy_lhs':
-                dX = _gspmm(g_rev, 'copy_lhs', 'sum', dZ, None)[0]
+                dX = gspmm(g_rev, 'copy_lhs', 'sum', dZ, None)
         else:
             dX = th.zeros((X.shape[0],) + dZ.shape[1:], dtype=X.dtype, device=X.device)
             if op in ['mul', 'div']:

@@ -83,12 +82,12 @@ class GSpMM(th.autograd.Function):
         if op != 'copy_lhs' and ctx.needs_input_grad[4]:
             if reduce_op == 'sum':
                 if op == 'mul' and _need_reduce_last_dim(X, Y):
-                    dY = _gsddmm(gidx, 'dot', X, dZ)
+                    dY = gsddmm(gidx, 'dot', X, dZ)
                 elif op in ['mul', 'div']:
-                    dY = _gsddmm(gidx, 'mul', X, dZ)
+                    dY = gsddmm(gidx, 'mul', X, dZ)
                     if op == 'div':
                         dY = -dY / (Y ** 2)
                 elif op in ['add', 'sub', 'copy_rhs']:
-                    dY = _gsddmm(gidx, 'copy_rhs', X, _addsub(op, dZ))
+                    dY = gsddmm(gidx, 'copy_rhs', X, _addsub(op, dZ))
             else:
                 dY = th.zeros((Y.shape[0],) + dZ.shape[1:], dtype=Y.dtype, device=Y.device)
                 if op in ['mul', 'div']:

@@ -104,8 +103,7 @@ class GSpMM(th.autograd.Function):

 class GSDDMM(th.autograd.Function):
     @staticmethod
-    def forward(ctx, g, op, X, Y, lhs_target, rhs_target):
-        gidx = g._graph
+    def forward(ctx, gidx, op, X, Y, lhs_target, rhs_target):
         out = _gsddmm(gidx, op, X, Y, lhs_target, rhs_target)
         ctx.backward_cache = gidx, op, lhs_target, rhs_target
         ctx.save_for_backward(X, Y)

@@ -119,19 +117,19 @@ class GSDDMM(th.autograd.Function):
         if lhs_target in ['u', 'v']:
             _gidx = gidx if lhs_target == 'v' else gidx.reverse()
             if op in ['add', 'sub', 'copy_lhs']:
-                dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0]
+                dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)
             else:  # mul, div, dot
                 if rhs_target == lhs_target:
-                    dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0] * _muldiv(op, Y)
+                    dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ) * _muldiv(op, Y)
                 elif rhs_target == 'e':
-                    dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * _muldiv(op, Y))[0]
+                    dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * _muldiv(op, Y))
                 else:  # rhs_target = !lhs_target
-                    dX = _gspmm(_gidx, 'mul', 'sum', _muldiv(op, Y), dZ)[0]
+                    dX = gspmm(_gidx, 'mul', 'sum', _muldiv(op, Y), dZ)
         else:  # lhs_target == 'e'
             if op in ['add', 'sub', 'copy_lhs']:
                 dX = dZ
             else:  # mul, div, dot
-                dX = _gsddmm(gidx, 'mul', dZ, _muldiv(op, Y), 'e', rhs_target)
+                dX = gsddmm(gidx, 'mul', dZ, _muldiv(op, Y), 'e', rhs_target)
         dX = _reduce_grad(dX, X.shape)
     else:
         dX = None

@@ -139,29 +137,31 @@ class GSDDMM(th.autograd.Function):
         if rhs_target in ['u', 'v']:
             _gidx = gidx if rhs_target == 'v' else gidx.reverse()
             if op in ['add', 'sub', 'copy_rhs']:
-                dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, _addsub(op, dZ))[0]
+                dY = gspmm(_gidx, 'copy_rhs', 'sum', None, _addsub(op, dZ))
             else:  # mul, div, dot
                 if lhs_target == rhs_target:
-                    dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0] * X
+                    dY = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ) * X
                 elif lhs_target == 'e':
-                    dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * X)[0]
+                    dY = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * X)
                 else:  # rhs_target = !lhs_target
-                    dY = _gspmm(_gidx, 'mul', 'sum', X, dZ)[0]
-                    if op == 'div':
-                        dY = -dY / (Y ** 2)
+                    dY = gspmm(_gidx, 'mul', 'sum', X, dZ)
+                    if op == 'div':
+                        dY = -dY / (Y ** 2)
         else:
             if op in ['add', 'sub', 'copy_rhs']:
                 dY = _addsub(op, dZ)
             else:  # mul, div, dot
-                dY = _gsddmm(gidx, 'mul', dZ, X, 'e', lhs_target)
-                if op == 'div':
-                    dY = -dY / (Y ** 2)
+                dY = gsddmm(gidx, 'mul', dZ, X, 'e', lhs_target)
+                if op == 'div':
+                    dY = -dY / (Y ** 2)
         dY = _reduce_grad(dY, Y.shape)
     else:
         dY = None
     return None, None, dX, dY, None, None

-def gspmm(g, op, reduce_op, lhs_data, rhs_data):
-    return GSpMM.apply(g, op, reduce_op, lhs_data, rhs_data)
+def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
+    return GSpMM.apply(gidx, op, reduce_op, lhs_data, rhs_data)

-def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
-    return GSDDMM.apply(g, op, lhs_data, rhs_data, lhs_target, rhs_target)
+def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
+    return GSDDMM.apply(gidx, op, lhs_data, rhs_data, lhs_target, rhs_target)
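The essential change in this file: every _gspmm/_gsddmm call inside backward (the raw, non-differentiable kernels) is replaced with the differentiable gspmm/gsddmm wrappers, so the gradient computation is itself recorded by autograd and can be differentiated again. A minimal standalone sketch of that pattern with a scalar function (not DGL code):

    import torch as th

    class Square(th.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return x * x

        @staticmethod
        def backward(ctx, dz):
            (x,) = ctx.saved_tensors
            # Built from differentiable tensor ops, so double backward works;
            # this is the same reason the code above calls gspmm/gsddmm
            # rather than _gspmm/_gsddmm.
            return 2 * x * dz

    x = th.tensor(3.0, requires_grad=True)
    y = Square.apply(x)
    (gx,) = th.autograd.grad(y, x, create_graph=True)   # gx = 2x = 6
    (ggx,) = th.autograd.grad(gx, x)                    # ggx = 2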
python/dgl/backend/tensorflow/sparse.py

@@ -85,8 +85,7 @@ def _muldiv(op, x):
 def _addsub(op, x):
     return -x if op == 'sub' else x

-def gspmm_real(g, op, reduce_op, X, Y):
-    gidx = g._graph
+def gspmm_real(gidx, op, reduce_op, X, Y):
     out, (argX, argY) = _gspmm(gidx, op, reduce_op, X, Y)

     def grad(dZ):

@@ -136,18 +135,17 @@ def gspmm_real(g, op, reduce_op, X, Y):
         return dX, dY
     return out, grad

-def gspmm(g, op, reduce_op, X, Y):
+def gspmm(gidx, op, reduce_op, X, Y):
     @tf.custom_gradient
     def _lambda(X, Y):
-        return gspmm_real(g, op, reduce_op, X, Y)
+        return gspmm_real(gidx, op, reduce_op, X, Y)
     if X is None:
         X = tf.zeros(())
     if Y is None:
         Y = tf.zeros(())
     return _lambda(X, Y)

-def gsddmm_real(g, op, X, Y, lhs_target, rhs_target):
-    gidx = g._graph
+def gsddmm_real(gidx, op, X, Y, lhs_target, rhs_target):
     out = _gsddmm(gidx, op, X, Y, lhs_target, rhs_target)

     def grad(dZ):

@@ -196,10 +194,10 @@ def gsddmm_real(g, op, X, Y, lhs_target, rhs_target):
         return dX, dY
     return out, grad

-def gsddmm(g, op, X, Y, lhs_target='u', rhs_target='v'):
+def gsddmm(gidx, op, X, Y, lhs_target='u', rhs_target='v'):
     @tf.custom_gradient
     def _lambda(X, Y):
-        return gsddmm_real(g, op, X, Y, lhs_target, rhs_target)
+        return gsddmm_real(gidx, op, X, Y, lhs_target, rhs_target)
     if X is None:
         X = tf.zeros(())
     if Y is None:
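TensorFlow reaches the same goal through tf.custom_gradient: as long as the returned grad function is built from differentiable TF ops, a nested GradientTape can differentiate the gradient again. A minimal standalone sketch (not DGL code):

    import tensorflow as tf

    @tf.custom_gradient
    def square(x):
        def grad(dz):
            return 2.0 * x * dz   # differentiable, so tape-over-tape works
        return x * x, grad

    x = tf.constant(3.0)
    with tf.GradientTape() as outer:
        outer.watch(x)
        with tf.GradientTape() as inner:
            inner.watch(x)
            y = square(x)
        gx = inner.gradient(y, x)   # 2x = 6
    ggx = outer.gradient(gx, x)     # 2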
python/dgl/ops/sddmm.py

@@ -2,10 +2,45 @@
 from itertools import product
 import sys

-from ..backend import gsddmm
+from ..backend import gsddmm as gsddmm_internal

+__all__ = ['gsddmm', 'copy_u', 'copy_v']

+def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
+    r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface.
+    It computes edge features by :attr:`op` lhs features and rhs features.
+
+    .. math::
+
+        x_{e} = \phi(x_{lhs}, x_{rhs}), \forall (u,e,v)\in \mathcal{G}
+
+    where :math:`x_{e}` is the returned feature on edges and :math:`x_u`,
+    :math:`x_v` refers to :attr:`u`, :attr:`v` respectively. :math:`\phi`
+    is the binary operator :attr:`op`, and :math:`\mathcal{G}` is the graph
+    we apply gsddmm on: :attr:`g`. :math:`lhs` and :math:`rhs` are one of
+    :math:`u`, :math:`v`, :math:`e`.
+
+    Parameters
+    ----------
+    g : DGLGraph
+        The input graph.
+    op : str
+        Binary operator, could be ``add``, ``sub``, ``mul``, ``div``, ``dot``,
+        ``copy_lhs``, ``copy_rhs``.
+    lhs_data : tensor or None
+        The left operand, could be None if it's not required by op.
+    rhs_data : tensor or None
+        The right operand, could be None if it's not required by op.
+    lhs_target : str
+        Choice of ``u`` (source), ``e`` (edge) or ``v`` (destination) for the left operand.
+    rhs_target : str
+        Choice of ``u`` (source), ``e`` (edge) or ``v`` (destination) for the right operand.
+
+    Returns
+    -------
+    tensor
+        The result tensor.
+    """
+    return gsddmm_internal(
+        g._graph, op, lhs_data, rhs_data, lhs_target, rhs_target)

 def _gen_sddmm_func(lhs_target, rhs_target, binary_op):
     name = "{}_{}_{}".format(lhs_target, binary_op, rhs_target)
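A usage sketch of the newly exported dgl.ops.gsddmm (the toy graph and shapes are illustrative assumptions):

    import torch
    import dgl
    from dgl.ops import gsddmm

    g = dgl.rand_graph(10, 30)
    u_feat = torch.randn(10, 4)
    v_feat = torch.randn(10, 4)

    # Per-edge score x_e = <x_u, x_v> for every edge (u, e, v);
    # default targets are lhs='u' (source) and rhs='v' (destination).
    e_score = gsddmm(g, 'dot', u_feat, v_feat)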
python/dgl/ops/spmm.py

 """dgl spmm operator module."""
 import sys

-from ..backend import gspmm
+from ..backend import gspmm as gspmm_internal

+__all__ = ['gspmm']

+def gspmm(g, op, reduce_op, lhs_data, rhs_data):
+    r""" Generalized Sparse Matrix Multiplication interface.
+    It fuses two steps into one kernel.
+    (1) Computes messages by :attr:`op` source node and edge features.
+    (2) Aggregates the messages by :attr:`reduce_op` as the features on destination nodes.
+
+    .. math::
+
+        x_v = \psi_{(u, v, e)\in \mathcal{G}}(\rho(x_u, x_e))
+
+    where :math:`x_v` is the returned feature on destination nodes, and :math:`x_u`,
+    :math:`x_e` refers to :attr:`u`, :attr:`e` respectively. :math:`\rho` means binary
+    operator :attr:`op` and :math:`\psi` means reduce operator :attr:`reduce_op`,
+    :math:`\mathcal{G}` is the graph we apply gspmm on: :attr:`g`.
+
+    Note that this function does not handle gradients.
+
+    Parameters
+    ----------
+    g : DGLGraph
+        The input graph.
+    op : str
+        The binary op's name, could be ``add``, ``sub``, ``mul``, ``div``,
+        ``copy_lhs``, ``copy_rhs``.
+    reduce_op : str
+        Reduce operator, could be ``sum``, ``max``, ``min``.
+    lhs_data : tensor or None
+        The left operand, could be None if it's not required by the op.
+    rhs_data : tensor or None
+        The right operand, could be None if it's not required by the op.
+
+    Returns
+    -------
+    tensor
+        The result tensor.
+    """
+    return gspmm_internal(g._graph, op, reduce_op, lhs_data, rhs_data)

 def _gen_spmm_func(binary_op, reduce_op):
     name = "u_{}_e_{}".format(binary_op, reduce_op)
     docstring = """Generalized SpMM function.
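And a usage sketch of dgl.ops.gspmm, again with an illustrative toy graph: a weighted-sum aggregation, i.e. one GCN-style message-passing step.

    import torch
    import dgl
    from dgl.ops import gspmm

    g = dgl.rand_graph(10, 30)
    node_feat = torch.randn(10, 4)
    edge_weight = torch.randn(30, 1)

    # h_v = sum over incoming edges (u, e, v) of x_u * w_e
    # (edge weights broadcast against the node feature dimension).
    h = gspmm(g, 'mul', 'sum', node_feat, edge_weight)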
python/dgl/sparse.py

@@ -109,9 +109,7 @@ def _gspmm(gidx, op, reduce_op, u, e):
     Notes
     -----
-    This function does not handle gradients, and for scalar input features,
-    we expand its dimension with an additional dimension of length one. (e.g.
-    (90,) to (90, 1) for a graph with 90 nodes/edges).
+    This function does not handle gradients.
     """
     if gidx.number_of_etypes() != 1:
         raise DGLError("We only support gspmm on graph with one edge type")

@@ -192,9 +190,7 @@ def _gsddmm(gidx, op, lhs, rhs, lhs_target='u', rhs_target='v'):
     Notes
     -----
-    This function does not handle gradients, and for scalar input features,
-    we expand its dimension with an additional dimension of length one. (e.g.
-    (90,) to (90, 1) for a graph with 90 nodes/edges).
+    This function does not handle gradients.
     """
     if gidx.number_of_etypes() != 1:
         raise DGLError("We only support gsddmm on graph with one edge type")
tests/compute/test_sparse.py

-from dgl.backend import gspmm, gsddmm
+from dgl.ops import gspmm, gsddmm
 from utils import parametrize_dtype

 import dgl
 import random
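The tests now exercise the differentiable dgl.ops wrappers rather than the backend symbols. A hedged sketch of how a double-backward property could be checked numerically (this is not part of the commit's test file):

    import torch
    import dgl
    from dgl.ops import gspmm

    g = dgl.rand_graph(5, 12)
    x = torch.randn(5, 3, dtype=torch.double, requires_grad=True)

    # gradgradcheck compares analytical second-order gradients against
    # finite differences; it requires double precision inputs.
    torch.autograd.gradgradcheck(
        lambda x: gspmm(g, 'copy_lhs', 'sum', x, None), (x,))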