Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
2fa2b453
Unverified
Commit
2fa2b453
authored
Jul 28, 2020
by
Zihao Ye
Committed by
GitHub
Jul 28, 2020
Browse files
[Feature] Support higher order derivative for message passing. (#1877)
* upd * fix typo
parent
2b8eb5be
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
131 additions
and
61 deletions
+131
-61
docs/source/api/python/ops.rst
docs/source/api/python/ops.rst
+2
-0
docs/source/index.rst
docs/source/index.rst
+1
-1
python/dgl/backend/backend.py
python/dgl/backend/backend.py
+4
-4
python/dgl/backend/mxnet/sparse.py
python/dgl/backend/mxnet/sparse.py
+15
-13
python/dgl/backend/pytorch/sparse.py
python/dgl/backend/pytorch/sparse.py
+26
-26
python/dgl/backend/tensorflow/sparse.py
python/dgl/backend/tensorflow/sparse.py
+6
-8
python/dgl/ops/sddmm.py
python/dgl/ops/sddmm.py
+36
-1
python/dgl/ops/spmm.py
python/dgl/ops/spmm.py
+38
-1
python/dgl/sparse.py
python/dgl/sparse.py
+2
-6
tests/compute/test_sparse.py
tests/compute/test_sparse.py
+1
-1
No files found.
docs/source/api/python/ops.rst
View file @
2fa2b453
...
@@ -87,6 +87,7 @@ graph.
...
@@ -87,6 +87,7 @@ graph.
.. autosummary::
.. autosummary::
:toctree: ../../generated/
:toctree: ../../generated/
gspmm
u_add_e_sum
u_add_e_sum
u_sub_e_sum
u_sub_e_sum
u_mul_e_sum
u_mul_e_sum
...
@@ -193,6 +194,7 @@ The following is an example showing how GSDDMM works:
...
@@ -193,6 +194,7 @@ The following is an example showing how GSDDMM works:
.. autosummary::
.. autosummary::
:toctree: ../../generated/
:toctree: ../../generated/
gsddmm
u_add_v
u_add_v
u_sub_v
u_sub_v
u_mul_v
u_mul_v
...
...
docs/source/index.rst
View file @
2fa2b453
...
@@ -101,7 +101,7 @@ a useful manual for in-depth developers.
...
@@ -101,7 +101,7 @@ a useful manual for in-depth developers.
api/python/graph
api/python/graph
api/python/heterograph
api/python/heterograph
api/python/
backend
api/python/
ops
api/python/readout
api/python/readout
api/python/batch_heterograph
api/python/batch_heterograph
api/python/nn
api/python/nn
...
...
python/dgl/backend/backend.py
View file @
2fa2b453
...
@@ -1377,7 +1377,7 @@ def copy_reduce(reducer, graph, target, in_data, out_size, in_map, out_map):
...
@@ -1377,7 +1377,7 @@ def copy_reduce(reducer, graph, target, in_data, out_size, in_map, out_map):
"""
"""
pass
pass
def
gspmm
(
g
,
op
,
reduce_op
,
lhs_data
,
rhs_data
):
def
gspmm
(
g
idx
,
op
,
reduce_op
,
lhs_data
,
rhs_data
):
r
""" Generalized Sparse Matrix Multiplication interface.
r
""" Generalized Sparse Matrix Multiplication interface.
It fuses two steps into one kernel.
It fuses two steps into one kernel.
(1) Computes messages by :attr:`op` source node and edge features.
(1) Computes messages by :attr:`op` source node and edge features.
...
@@ -1395,7 +1395,7 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data):
...
@@ -1395,7 +1395,7 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data):
Parameters
Parameters
----------
----------
g :
DGL
HeteroGraph
g
idx
: HeteroGraph
Index
The input graph.
The input graph.
op : str
op : str
The binary op's name, could be ``add``, ``sub``, ``mul``, ``div``,
The binary op's name, could be ``add``, ``sub``, ``mul``, ``div``,
...
@@ -1414,7 +1414,7 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data):
...
@@ -1414,7 +1414,7 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data):
"""
"""
pass
pass
def
gsddmm
(
g
,
op
,
lhs_data
,
rhs_data
,
lhs_target
=
'u'
,
rhs_target
=
'v'
):
def
gsddmm
(
g
idx
,
op
,
lhs_data
,
rhs_data
,
lhs_target
=
'u'
,
rhs_target
=
'v'
):
r
""" Generalized Sampled-Dense-Dense Matrix Multiplication interface.
r
""" Generalized Sampled-Dense-Dense Matrix Multiplication interface.
It computes edge features by :attr:`op` lhs features and rhs features.
It computes edge features by :attr:`op` lhs features and rhs features.
...
@@ -1428,7 +1428,7 @@ def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
...
@@ -1428,7 +1428,7 @@ def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
Parameters
Parameters
----------
----------
g :
DGL
HeteroGraph
g
idx
: HeteroGraph
Index
The input graph.
The input graph.
op : str
op : str
Binary operator, could be ``add``, ``sub``, ``mul``, ``div``, ``dot``,
Binary operator, could be ``add``, ``sub``, ``mul``, ``div``, ``dot``,
...
...
python/dgl/backend/mxnet/sparse.py
View file @
2fa2b453
...
@@ -3,7 +3,7 @@ import numpy as np
...
@@ -3,7 +3,7 @@ import numpy as np
from
mxnet
import
nd
from
mxnet
import
nd
from
...sparse
import
_gspmm
,
_gsddmm
from
...sparse
import
_gspmm
,
_gsddmm
from
...base
import
dgl_warning
from
...base
import
dgl_warning
from
.tensor
import
asnumpy
,
copy_to
,
zerocopy_from_numpy
,
context
from
.tensor
import
asnumpy
,
copy_to
,
zerocopy_from_numpy
,
context
,
to_backend_ctx
def
_scatter_nd
(
index
,
src
,
n_rows
):
def
_scatter_nd
(
index
,
src
,
n_rows
):
assert
index
.
shape
==
src
.
shape
assert
index
.
shape
==
src
.
shape
...
@@ -95,9 +95,9 @@ def _addsub(op, x):
...
@@ -95,9 +95,9 @@ def _addsub(op, x):
return
-
x
if
op
==
'sub'
else
x
return
-
x
if
op
==
'sub'
else
x
class
GSpMM
(
mx
.
autograd
.
Function
):
class
GSpMM
(
mx
.
autograd
.
Function
):
def
__init__
(
self
,
g
,
op
,
reduce_op
):
def
__init__
(
self
,
g
idx
,
op
,
reduce_op
):
super
(
GSpMM
,
self
).
__init__
()
super
(
GSpMM
,
self
).
__init__
()
self
.
gidx
=
g
.
_graph
self
.
gidx
=
g
idx
self
.
op
=
op
self
.
op
=
op
self
.
reduce_op
=
reduce_op
self
.
reduce_op
=
reduce_op
...
@@ -154,18 +154,19 @@ class GSpMM(mx.autograd.Function):
...
@@ -154,18 +154,19 @@ class GSpMM(mx.autograd.Function):
self
.
saved_tensors
=
None
self
.
saved_tensors
=
None
return
dX
,
dY
return
dX
,
dY
def
gspmm
(
g
,
op
,
reduce_op
,
lhs_data
,
rhs_data
):
def
gspmm
(
gidx
,
op
,
reduce_op
,
lhs_data
,
rhs_data
):
func
=
GSpMM
(
g
,
op
,
reduce_op
)
func
=
GSpMM
(
gidx
,
op
,
reduce_op
)
ctx
=
to_backend_ctx
(
gidx
.
ctx
)
if
lhs_data
is
None
:
if
lhs_data
is
None
:
lhs_data
=
nd
.
zeros
((
1
,),
ctx
=
g
.
device
)
lhs_data
=
nd
.
zeros
((
1
,),
ctx
=
ctx
)
if
rhs_data
is
None
:
if
rhs_data
is
None
:
rhs_data
=
nd
.
zeros
((
1
,),
ctx
=
g
.
device
)
rhs_data
=
nd
.
zeros
((
1
,),
ctx
=
ctx
)
return
func
(
lhs_data
,
rhs_data
)
return
func
(
lhs_data
,
rhs_data
)
class
GSDDMM
(
mx
.
autograd
.
Function
):
class
GSDDMM
(
mx
.
autograd
.
Function
):
def
__init__
(
self
,
g
,
op
,
lhs_target
,
rhs_target
):
def
__init__
(
self
,
g
idx
,
op
,
lhs_target
,
rhs_target
):
super
(
GSDDMM
,
self
).
__init__
()
super
(
GSDDMM
,
self
).
__init__
()
self
.
gidx
=
g
.
_graph
self
.
gidx
=
g
idx
self
.
op
=
op
self
.
op
=
op
self
.
lhs_target
=
lhs_target
self
.
lhs_target
=
lhs_target
self
.
rhs_target
=
rhs_target
self
.
rhs_target
=
rhs_target
...
@@ -225,10 +226,11 @@ class GSDDMM(mx.autograd.Function):
...
@@ -225,10 +226,11 @@ class GSDDMM(mx.autograd.Function):
self
.
saved_tensors
=
None
self
.
saved_tensors
=
None
return
dX
,
dY
return
dX
,
dY
def
gsddmm
(
g
,
op
,
lhs_data
,
rhs_data
,
lhs_target
=
'u'
,
rhs_target
=
'v'
):
def
gsddmm
(
gidx
,
op
,
lhs_data
,
rhs_data
,
lhs_target
=
'u'
,
rhs_target
=
'v'
):
func
=
GSDDMM
(
g
,
op
,
lhs_target
,
rhs_target
)
func
=
GSDDMM
(
gidx
,
op
,
lhs_target
,
rhs_target
)
ctx
=
to_backend_ctx
(
gidx
.
ctx
)
if
lhs_data
is
None
:
if
lhs_data
is
None
:
lhs_data
=
nd
.
zeros
((
1
,),
ctx
=
g
.
device
)
lhs_data
=
nd
.
zeros
((
1
,),
ctx
=
ctx
)
if
rhs_data
is
None
:
if
rhs_data
is
None
:
rhs_data
=
nd
.
zeros
((
1
,),
ctx
=
g
.
device
)
rhs_data
=
nd
.
zeros
((
1
,),
ctx
=
ctx
)
return
func
(
lhs_data
,
rhs_data
)
return
func
(
lhs_data
,
rhs_data
)
python/dgl/backend/pytorch/sparse.py
View file @
2fa2b453
...
@@ -50,8 +50,7 @@ def _addsub(op, x):
...
@@ -50,8 +50,7 @@ def _addsub(op, x):
class
GSpMM
(
th
.
autograd
.
Function
):
class
GSpMM
(
th
.
autograd
.
Function
):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
g
,
op
,
reduce_op
,
X
,
Y
):
def
forward
(
ctx
,
gidx
,
op
,
reduce_op
,
X
,
Y
):
gidx
=
g
.
_graph
out
,
(
argX
,
argY
)
=
_gspmm
(
gidx
,
op
,
reduce_op
,
X
,
Y
)
out
,
(
argX
,
argY
)
=
_gspmm
(
gidx
,
op
,
reduce_op
,
X
,
Y
)
ctx
.
backward_cache
=
gidx
,
op
,
reduce_op
ctx
.
backward_cache
=
gidx
,
op
,
reduce_op
ctx
.
save_for_backward
(
X
,
Y
,
argX
,
argY
)
ctx
.
save_for_backward
(
X
,
Y
,
argX
,
argY
)
...
@@ -65,11 +64,11 @@ class GSpMM(th.autograd.Function):
...
@@ -65,11 +64,11 @@ class GSpMM(th.autograd.Function):
g_rev
=
gidx
.
reverse
()
g_rev
=
gidx
.
reverse
()
if
reduce_op
==
'sum'
:
if
reduce_op
==
'sum'
:
if
op
in
[
'mul'
,
'div'
]:
if
op
in
[
'mul'
,
'div'
]:
dX
=
_
gspmm
(
g_rev
,
'mul'
,
'sum'
,
dZ
,
_muldiv
(
op
,
Y
))
[
0
]
dX
=
gspmm
(
g_rev
,
'mul'
,
'sum'
,
dZ
,
_muldiv
(
op
,
Y
))
elif
op
in
[
'add'
,
'sub'
]:
elif
op
in
[
'add'
,
'sub'
]:
dX
=
_
gspmm
(
g_rev
,
'copy_lhs'
,
'sum'
,
dZ
,
Y
)
[
0
]
dX
=
gspmm
(
g_rev
,
'copy_lhs'
,
'sum'
,
dZ
,
Y
)
elif
op
==
'copy_lhs'
:
elif
op
==
'copy_lhs'
:
dX
=
_
gspmm
(
g_rev
,
'copy_lhs'
,
'sum'
,
dZ
,
None
)
[
0
]
dX
=
gspmm
(
g_rev
,
'copy_lhs'
,
'sum'
,
dZ
,
None
)
else
:
else
:
dX
=
th
.
zeros
((
X
.
shape
[
0
],)
+
dZ
.
shape
[
1
:],
dtype
=
X
.
dtype
,
device
=
X
.
device
)
dX
=
th
.
zeros
((
X
.
shape
[
0
],)
+
dZ
.
shape
[
1
:],
dtype
=
X
.
dtype
,
device
=
X
.
device
)
if
op
in
[
'mul'
,
'div'
]:
if
op
in
[
'mul'
,
'div'
]:
...
@@ -83,12 +82,12 @@ class GSpMM(th.autograd.Function):
...
@@ -83,12 +82,12 @@ class GSpMM(th.autograd.Function):
if
op
!=
'copy_lhs'
and
ctx
.
needs_input_grad
[
4
]:
if
op
!=
'copy_lhs'
and
ctx
.
needs_input_grad
[
4
]:
if
reduce_op
==
'sum'
:
if
reduce_op
==
'sum'
:
if
op
==
'mul'
and
_need_reduce_last_dim
(
X
,
Y
):
if
op
==
'mul'
and
_need_reduce_last_dim
(
X
,
Y
):
dY
=
_
gsddmm
(
gidx
,
'dot'
,
X
,
dZ
)
dY
=
gsddmm
(
gidx
,
'dot'
,
X
,
dZ
)
elif
op
in
[
'mul'
,
'div'
]:
elif
op
in
[
'mul'
,
'div'
]:
dY
=
_
gsddmm
(
gidx
,
'mul'
,
X
,
dZ
)
dY
=
gsddmm
(
gidx
,
'mul'
,
X
,
dZ
)
if
op
==
'div'
:
dY
=
-
dY
/
(
Y
**
2
)
if
op
==
'div'
:
dY
=
-
dY
/
(
Y
**
2
)
elif
op
in
[
'add'
,
'sub'
,
'copy_rhs'
]:
elif
op
in
[
'add'
,
'sub'
,
'copy_rhs'
]:
dY
=
_
gsddmm
(
gidx
,
'copy_rhs'
,
X
,
_addsub
(
op
,
dZ
))
dY
=
gsddmm
(
gidx
,
'copy_rhs'
,
X
,
_addsub
(
op
,
dZ
))
else
:
else
:
dY
=
th
.
zeros
((
Y
.
shape
[
0
],)
+
dZ
.
shape
[
1
:],
dtype
=
Y
.
dtype
,
device
=
Y
.
device
)
dY
=
th
.
zeros
((
Y
.
shape
[
0
],)
+
dZ
.
shape
[
1
:],
dtype
=
Y
.
dtype
,
device
=
Y
.
device
)
if
op
in
[
'mul'
,
'div'
]:
if
op
in
[
'mul'
,
'div'
]:
...
@@ -104,8 +103,7 @@ class GSpMM(th.autograd.Function):
...
@@ -104,8 +103,7 @@ class GSpMM(th.autograd.Function):
class
GSDDMM
(
th
.
autograd
.
Function
):
class
GSDDMM
(
th
.
autograd
.
Function
):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
g
,
op
,
X
,
Y
,
lhs_target
,
rhs_target
):
def
forward
(
ctx
,
gidx
,
op
,
X
,
Y
,
lhs_target
,
rhs_target
):
gidx
=
g
.
_graph
out
=
_gsddmm
(
gidx
,
op
,
X
,
Y
,
lhs_target
,
rhs_target
)
out
=
_gsddmm
(
gidx
,
op
,
X
,
Y
,
lhs_target
,
rhs_target
)
ctx
.
backward_cache
=
gidx
,
op
,
lhs_target
,
rhs_target
ctx
.
backward_cache
=
gidx
,
op
,
lhs_target
,
rhs_target
ctx
.
save_for_backward
(
X
,
Y
)
ctx
.
save_for_backward
(
X
,
Y
)
...
@@ -119,19 +117,19 @@ class GSDDMM(th.autograd.Function):
...
@@ -119,19 +117,19 @@ class GSDDMM(th.autograd.Function):
if
lhs_target
in
[
'u'
,
'v'
]:
if
lhs_target
in
[
'u'
,
'v'
]:
_gidx
=
gidx
if
lhs_target
==
'v'
else
gidx
.
reverse
()
_gidx
=
gidx
if
lhs_target
==
'v'
else
gidx
.
reverse
()
if
op
in
[
'add'
,
'sub'
,
'copy_lhs'
]:
if
op
in
[
'add'
,
'sub'
,
'copy_lhs'
]:
dX
=
_
gspmm
(
_gidx
,
'copy_rhs'
,
'sum'
,
None
,
dZ
)
[
0
]
dX
=
gspmm
(
_gidx
,
'copy_rhs'
,
'sum'
,
None
,
dZ
)
else
:
# mul, div, dot
else
:
# mul, div, dot
if
rhs_target
==
lhs_target
:
if
rhs_target
==
lhs_target
:
dX
=
_
gspmm
(
_gidx
,
'copy_rhs'
,
'sum'
,
None
,
dZ
)
[
0
]
*
_muldiv
(
op
,
Y
)
dX
=
gspmm
(
_gidx
,
'copy_rhs'
,
'sum'
,
None
,
dZ
)
*
_muldiv
(
op
,
Y
)
elif
rhs_target
==
'e'
:
elif
rhs_target
==
'e'
:
dX
=
_
gspmm
(
_gidx
,
'copy_rhs'
,
'sum'
,
None
,
dZ
*
_muldiv
(
op
,
Y
))
[
0
]
dX
=
gspmm
(
_gidx
,
'copy_rhs'
,
'sum'
,
None
,
dZ
*
_muldiv
(
op
,
Y
))
else
:
# rhs_target = !lhs_target
else
:
# rhs_target = !lhs_target
dX
=
_
gspmm
(
_gidx
,
'mul'
,
'sum'
,
_muldiv
(
op
,
Y
),
dZ
)
[
0
]
dX
=
gspmm
(
_gidx
,
'mul'
,
'sum'
,
_muldiv
(
op
,
Y
),
dZ
)
else
:
# lhs_target == 'e'
else
:
# lhs_target == 'e'
if
op
in
[
'add'
,
'sub'
,
'copy_lhs'
]:
if
op
in
[
'add'
,
'sub'
,
'copy_lhs'
]:
dX
=
dZ
dX
=
dZ
else
:
# mul, div, dot
else
:
# mul, div, dot
dX
=
_
gsddmm
(
gidx
,
'mul'
,
dZ
,
_muldiv
(
op
,
Y
),
'e'
,
rhs_target
)
dX
=
gsddmm
(
gidx
,
'mul'
,
dZ
,
_muldiv
(
op
,
Y
),
'e'
,
rhs_target
)
dX
=
_reduce_grad
(
dX
,
X
.
shape
)
dX
=
_reduce_grad
(
dX
,
X
.
shape
)
else
:
else
:
dX
=
None
dX
=
None
...
@@ -139,29 +137,31 @@ class GSDDMM(th.autograd.Function):
...
@@ -139,29 +137,31 @@ class GSDDMM(th.autograd.Function):
if
rhs_target
in
[
'u'
,
'v'
]:
if
rhs_target
in
[
'u'
,
'v'
]:
_gidx
=
gidx
if
rhs_target
==
'v'
else
gidx
.
reverse
()
_gidx
=
gidx
if
rhs_target
==
'v'
else
gidx
.
reverse
()
if
op
in
[
'add'
,
'sub'
,
'copy_rhs'
]:
if
op
in
[
'add'
,
'sub'
,
'copy_rhs'
]:
dY
=
_
gspmm
(
_gidx
,
'copy_rhs'
,
'sum'
,
None
,
_addsub
(
op
,
dZ
))
[
0
]
dY
=
gspmm
(
_gidx
,
'copy_rhs'
,
'sum'
,
None
,
_addsub
(
op
,
dZ
))
else
:
# mul, div, dot
else
:
# mul, div, dot
if
lhs_target
==
rhs_target
:
if
lhs_target
==
rhs_target
:
dY
=
_
gspmm
(
_gidx
,
'copy_rhs'
,
'sum'
,
None
,
dZ
)
[
0
]
*
X
dY
=
gspmm
(
_gidx
,
'copy_rhs'
,
'sum'
,
None
,
dZ
)
*
X
elif
lhs_target
==
'e'
:
elif
lhs_target
==
'e'
:
dY
=
_
gspmm
(
_gidx
,
'copy_rhs'
,
'sum'
,
None
,
dZ
*
X
)
[
0
]
dY
=
gspmm
(
_gidx
,
'copy_rhs'
,
'sum'
,
None
,
dZ
*
X
)
else
:
# rhs_target = !lhs_target
else
:
# rhs_target = !lhs_target
dY
=
_gspmm
(
_gidx
,
'mul'
,
'sum'
,
X
,
dZ
)[
0
]
dY
=
gspmm
(
_gidx
,
'mul'
,
'sum'
,
X
,
dZ
)
if
op
==
'div'
:
dY
=
-
dY
/
(
Y
**
2
)
if
op
==
'div'
:
dY
=
-
dY
/
(
Y
**
2
)
else
:
else
:
if
op
in
[
'add'
,
'sub'
,
'copy_rhs'
]:
if
op
in
[
'add'
,
'sub'
,
'copy_rhs'
]:
dY
=
_addsub
(
op
,
dZ
)
dY
=
_addsub
(
op
,
dZ
)
else
:
# mul, div, dot
else
:
# mul, div, dot
dY
=
_gsddmm
(
gidx
,
'mul'
,
dZ
,
X
,
'e'
,
lhs_target
)
dY
=
gsddmm
(
gidx
,
'mul'
,
dZ
,
X
,
'e'
,
lhs_target
)
if
op
==
'div'
:
dY
=
-
dY
/
(
Y
**
2
)
if
op
==
'div'
:
dY
=
-
dY
/
(
Y
**
2
)
dY
=
_reduce_grad
(
dY
,
Y
.
shape
)
dY
=
_reduce_grad
(
dY
,
Y
.
shape
)
else
:
else
:
dY
=
None
dY
=
None
return
None
,
None
,
dX
,
dY
,
None
,
None
return
None
,
None
,
dX
,
dY
,
None
,
None
def
gspmm
(
g
,
op
,
reduce_op
,
lhs_data
,
rhs_data
):
def
gspmm
(
g
idx
,
op
,
reduce_op
,
lhs_data
,
rhs_data
):
return
GSpMM
.
apply
(
g
,
op
,
reduce_op
,
lhs_data
,
rhs_data
)
return
GSpMM
.
apply
(
g
idx
,
op
,
reduce_op
,
lhs_data
,
rhs_data
)
def
gsddmm
(
g
,
op
,
lhs_data
,
rhs_data
,
lhs_target
=
'u'
,
rhs_target
=
'v'
):
def
gsddmm
(
g
idx
,
op
,
lhs_data
,
rhs_data
,
lhs_target
=
'u'
,
rhs_target
=
'v'
):
return
GSDDMM
.
apply
(
g
,
op
,
lhs_data
,
rhs_data
,
lhs_target
,
rhs_target
)
return
GSDDMM
.
apply
(
g
idx
,
op
,
lhs_data
,
rhs_data
,
lhs_target
,
rhs_target
)
python/dgl/backend/tensorflow/sparse.py
View file @
2fa2b453
...
@@ -85,8 +85,7 @@ def _muldiv(op, x):
...
@@ -85,8 +85,7 @@ def _muldiv(op, x):
def
_addsub
(
op
,
x
):
def
_addsub
(
op
,
x
):
return
-
x
if
op
==
'sub'
else
x
return
-
x
if
op
==
'sub'
else
x
def
gspmm_real
(
g
,
op
,
reduce_op
,
X
,
Y
):
def
gspmm_real
(
gidx
,
op
,
reduce_op
,
X
,
Y
):
gidx
=
g
.
_graph
out
,
(
argX
,
argY
)
=
_gspmm
(
gidx
,
op
,
reduce_op
,
X
,
Y
)
out
,
(
argX
,
argY
)
=
_gspmm
(
gidx
,
op
,
reduce_op
,
X
,
Y
)
def
grad
(
dZ
):
def
grad
(
dZ
):
...
@@ -136,18 +135,17 @@ def gspmm_real(g, op, reduce_op, X, Y):
...
@@ -136,18 +135,17 @@ def gspmm_real(g, op, reduce_op, X, Y):
return
dX
,
dY
return
dX
,
dY
return
out
,
grad
return
out
,
grad
def
gspmm
(
g
,
op
,
reduce_op
,
X
,
Y
):
def
gspmm
(
g
idx
,
op
,
reduce_op
,
X
,
Y
):
@
tf
.
custom_gradient
@
tf
.
custom_gradient
def
_lambda
(
X
,
Y
):
def
_lambda
(
X
,
Y
):
return
gspmm_real
(
g
,
op
,
reduce_op
,
X
,
Y
)
return
gspmm_real
(
g
idx
,
op
,
reduce_op
,
X
,
Y
)
if
X
is
None
:
if
X
is
None
:
X
=
tf
.
zeros
(())
X
=
tf
.
zeros
(())
if
Y
is
None
:
if
Y
is
None
:
Y
=
tf
.
zeros
(())
Y
=
tf
.
zeros
(())
return
_lambda
(
X
,
Y
)
return
_lambda
(
X
,
Y
)
def
gsddmm_real
(
g
,
op
,
X
,
Y
,
lhs_target
,
rhs_target
):
def
gsddmm_real
(
gidx
,
op
,
X
,
Y
,
lhs_target
,
rhs_target
):
gidx
=
g
.
_graph
out
=
_gsddmm
(
gidx
,
op
,
X
,
Y
,
lhs_target
,
rhs_target
)
out
=
_gsddmm
(
gidx
,
op
,
X
,
Y
,
lhs_target
,
rhs_target
)
def
grad
(
dZ
):
def
grad
(
dZ
):
...
@@ -196,10 +194,10 @@ def gsddmm_real(g, op, X, Y, lhs_target, rhs_target):
...
@@ -196,10 +194,10 @@ def gsddmm_real(g, op, X, Y, lhs_target, rhs_target):
return
dX
,
dY
return
dX
,
dY
return
out
,
grad
return
out
,
grad
def
gsddmm
(
g
,
op
,
X
,
Y
,
lhs_target
=
'u'
,
rhs_target
=
'v'
):
def
gsddmm
(
g
idx
,
op
,
X
,
Y
,
lhs_target
=
'u'
,
rhs_target
=
'v'
):
@
tf
.
custom_gradient
@
tf
.
custom_gradient
def
_lambda
(
X
,
Y
):
def
_lambda
(
X
,
Y
):
return
gsddmm_real
(
g
,
op
,
X
,
Y
,
lhs_target
,
rhs_target
)
return
gsddmm_real
(
g
idx
,
op
,
X
,
Y
,
lhs_target
,
rhs_target
)
if
X
is
None
:
if
X
is
None
:
X
=
tf
.
zeros
(())
X
=
tf
.
zeros
(())
if
Y
is
None
:
if
Y
is
None
:
...
...
python/dgl/ops/sddmm.py
View file @
2fa2b453
...
@@ -2,10 +2,45 @@
...
@@ -2,10 +2,45 @@
from
itertools
import
product
from
itertools
import
product
import
sys
import
sys
from
..backend
import
gsddmm
from
..backend
import
gsddmm
as
gsddmm_internal
__all__
=
[
'gsddmm'
,
'copy_u'
,
'copy_v'
]
__all__
=
[
'gsddmm'
,
'copy_u'
,
'copy_v'
]
def
gsddmm
(
g
,
op
,
lhs_data
,
rhs_data
,
lhs_target
=
'u'
,
rhs_target
=
'v'
):
r
""" Generalized Sampled-Dense-Dense Matrix Multiplication interface.
It computes edge features by :attr:`op` lhs features and rhs features.
.. math::
x_{e} = \phi(x_{lhs}, x_{rhs}), \forall (u,e,v)\in \mathcal{G}
where :math:`x_{e}` is the returned feature on edges and :math:`x_u`,
:math:`x_v` refers to :attr:`u`, :attr:`v` respectively. :math:`\phi`
is the binary operator :attr:`op`, and :math:`\mathcal{G}` is the graph
we apply gsddmm on: :attr:`g`. $lhs$ and $rhs$ are one of $u,v,e$'s.
Parameters
----------
g : DGLGraph
The input graph.
op : str
Binary operator, could be ``add``, ``sub``, ``mul``, ``div``, ``dot``,
``copy_lhs``, ``copy_rhs``.
lhs_data : tensor or None
The left operand, could be None if it's not required by op.
rhs_data : tensor or None
The right operand, could be None if it's not required by op.
lhs_target: str
Choice of `u`(source), `e`(edge) or `v`(destination) for left operand.
rhs_target: str
Choice of `u`(source), `e`(edge) or `v`(destination) for right operand.
Returns
-------
tensor
The result tensor.
"""
return
gsddmm_internal
(
g
.
_graph
,
op
,
lhs_data
,
rhs_data
,
lhs_target
,
rhs_target
)
def
_gen_sddmm_func
(
lhs_target
,
rhs_target
,
binary_op
):
def
_gen_sddmm_func
(
lhs_target
,
rhs_target
,
binary_op
):
name
=
"{}_{}_{}"
.
format
(
lhs_target
,
binary_op
,
rhs_target
)
name
=
"{}_{}_{}"
.
format
(
lhs_target
,
binary_op
,
rhs_target
)
...
...
python/dgl/ops/spmm.py
View file @
2fa2b453
"""dgl spmm operator module."""
"""dgl spmm operator module."""
import
sys
import
sys
from
..backend
import
gspmm
from
..backend
import
gspmm
as
gspmm_internal
__all__
=
[
'gspmm'
]
__all__
=
[
'gspmm'
]
def
gspmm
(
g
,
op
,
reduce_op
,
lhs_data
,
rhs_data
):
r
""" Generalized Sparse Matrix Multiplication interface.
It fuses two steps into one kernel.
(1) Computes messages by :attr:`op` source node and edge features.
(2) Aggregate the messages by :attr:`reduce_op` as the features on destination nodes.
.. math::
x_v = \psi_{(u, v, e)\in \mathcal{G}}(\rho(x_u, x_e))
where :math:`x_v` is the returned feature on destination nodes, and :math`x_u`,
:math:`x_e` refers to :attr:`u`, :attr:`e` respectively. :math:`\rho` means binary
operator :attr:`op` and :math:`\psi` means reduce operator :attr:`reduce_op`,
:math:`\mathcal{G}` is the graph we apply gspmm on: :attr:`g`.
Note that this function does not handle gradients.
Parameters
----------
g : DGLGraph
The input graph.
op : str
The binary op's name, could be ``add``, ``sub``, ``mul``, ``div``,
``copy_lhs``, ``copy_rhs``.
reduce_op : str
Reduce operator, could be ``sum``, ``max``, ``min``.
lhs_data : tensor or None
The left operand, could be None if it's not required by the op.
rhs_data : tensor or None
The right operand, could be None if it's not required by the op.
Returns
-------
tensor
The result tensor.
"""
return
gspmm_internal
(
g
.
_graph
,
op
,
reduce_op
,
lhs_data
,
rhs_data
)
def
_gen_spmm_func
(
binary_op
,
reduce_op
):
def
_gen_spmm_func
(
binary_op
,
reduce_op
):
name
=
"u_{}_e_{}"
.
format
(
binary_op
,
reduce_op
)
name
=
"u_{}_e_{}"
.
format
(
binary_op
,
reduce_op
)
docstring
=
"""Generalized SpMM function.
docstring
=
"""Generalized SpMM function.
...
...
python/dgl/sparse.py
View file @
2fa2b453
...
@@ -109,9 +109,7 @@ def _gspmm(gidx, op, reduce_op, u, e):
...
@@ -109,9 +109,7 @@ def _gspmm(gidx, op, reduce_op, u, e):
Notes
Notes
-----
-----
This function does not handle gradients, and for scalar input features,
This function does not handle gradients.
we expand its dimension with an additional dimension of length one. (e.g.
(90,) to (90, 1) for a graph with 90 nodes/edges).
"""
"""
if
gidx
.
number_of_etypes
()
!=
1
:
if
gidx
.
number_of_etypes
()
!=
1
:
raise
DGLError
(
"We only support gspmm on graph with one edge type"
)
raise
DGLError
(
"We only support gspmm on graph with one edge type"
)
...
@@ -192,9 +190,7 @@ def _gsddmm(gidx, op, lhs, rhs, lhs_target='u', rhs_target='v'):
...
@@ -192,9 +190,7 @@ def _gsddmm(gidx, op, lhs, rhs, lhs_target='u', rhs_target='v'):
Notes
Notes
-----
-----
This function does not handle gradients, and for scalar input features,
This function does not handle gradients.
we expand its dimension with an additional dimension of length one. (e.g.
(90,) to (90, 1) for a graph with 90 nodes/edges).
"""
"""
if
gidx
.
number_of_etypes
()
!=
1
:
if
gidx
.
number_of_etypes
()
!=
1
:
raise
DGLError
(
"We only support gsddmm on graph with one edge type"
)
raise
DGLError
(
"We only support gsddmm on graph with one edge type"
)
...
...
tests/compute/test_sparse.py
View file @
2fa2b453
from
dgl.
backend
import
gspmm
,
gsddmm
from
dgl.
ops
import
gspmm
,
gsddmm
from
utils
import
parametrize_dtype
from
utils
import
parametrize_dtype
import
dgl
import
dgl
import
random
import
random
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment