OpenDAS / dgl · Commit 878acdb0 (Unverified)
Authored Jan 28, 2021 by Minjie Wang; committed by GitHub on Jan 28, 2021
Revert "Refactor code for retaining formats in message-passing. (#2570)" (#2583)
This reverts commit
a613ad88
.
parent
7bab1365
Showing 5 changed files with 39 additions and 47 deletions (+39 / −47)
python/dgl/backend/mxnet/sparse.py         +5   −5
python/dgl/backend/pytorch/sparse.py       +29  −1
python/dgl/backend/tensorflow/sparse.py    +5   −5
python/dgl/heterograph.py                  +0   −7
python/dgl/sparse.py                       +0   −29
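For context, the reverted refactor (#2570, commit a613ad88) had moved a format-retaining `_reverse` helper into python/dgl/sparse.py and made all three backends import it; this revert removes the shared helper, restores the local copy in the PyTorch backend, and returns the MXNet and TensorFlow backends to plain `gidx.reverse()`. The helper reverses a graph index while keeping its materialized sparse formats usable by mapping each format to its transpose counterpart (coo stays coo, csr becomes csc, csc becomes csr); its definition appears below both as an addition to python/dgl/backend/pytorch/sparse.py and as a removal from python/dgl/sparse.py. Here is a minimal, self-contained sketch of that idea; `FakeGraphIndex` is a hypothetical stand-in for DGL's internal HeteroGraphIndex, not a real DGL class:

    # Sketch only: FakeGraphIndex mimics the two HeteroGraphIndex methods that
    # _reverse relies on (reverse() and formats()); it is not part of DGL.
    _inverse_format = {'coo': 'coo', 'csr': 'csc', 'csc': 'csr'}


    class FakeGraphIndex:
        def __init__(self, formats_created, formats_allowed):
            self._created = list(formats_created)
            self._allowed = list(formats_allowed)

        def formats(self, restrict=None):
            if restrict is None:
                return {'created': self._created,
                        'not created': [f for f in self._allowed
                                        if f not in self._created]}
            # Restricting formats returns a new index limited to those formats.
            return FakeGraphIndex([], restrict)

        def reverse(self):
            # Plain reverse: the original's format bookkeeping is not carried over.
            return FakeGraphIndex([], ['coo', 'csr', 'csc'])


    def _reverse(gidx):
        """Reverse a graph index while retaining its formats (the reverted pattern)."""
        g_rev = gidx.reverse()
        fmts = gidx.formats()
        original_formats = fmts['created'] + fmts['not created']
        return g_rev.formats([_inverse_format[fmt] for fmt in original_formats])


    g = FakeGraphIndex(formats_created=['csr'], formats_allowed=['csr'])
    print(_reverse(g).formats())   # {'created': [], 'not created': ['csc']}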
python/dgl/backend/mxnet/sparse.py

 import mxnet as mx
 import numpy as np
 from mxnet import nd
-from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _reverse
+from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp
 from ...base import dgl_warning, is_all, ALL
 from .tensor import asnumpy, copy_to, zerocopy_from_numpy, context, to_backend_ctx

@@ -132,7 +132,7 @@ class GSpMM(mx.autograd.Function):
         X, Y, argX, argY = self.saved_tensors
         gidx, op, reduce_op = self.gidx, self.op, self.reduce_op
         if op != 'copy_rhs':
-            g_rev = _reverse(gidx)
+            g_rev = gidx.reverse()
             if reduce_op == 'sum':
                 if op in ['mul', 'div']:
                     dX = _gspmm(g_rev, 'mul', 'sum', dZ, _muldiv(op, Y))[0]

@@ -215,7 +215,7 @@ class GSDDMM(mx.autograd.Function):
         lhs_target, rhs_target = self.lhs_target, self.rhs_target
         if op != 'copy_rhs':
             if lhs_target in ['u', 'v']:
-                _gidx = gidx if self.lhs_target == 'v' else _reverse(gidx)
+                _gidx = gidx if self.lhs_target == 'v' else gidx.reverse()
                 if op in ['add', 'sub', 'copy_lhs']:
                     dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0]
                 else:  # mul, div, dot

@@ -235,7 +235,7 @@ class GSDDMM(mx.autograd.Function):
             dX = nd.zeros_like(X)
         if op != 'copy_lhs':
             if self.rhs_target in ['u', 'v']:
-                _gidx = gidx if rhs_target == 'v' else _reverse(gidx)
+                _gidx = gidx if rhs_target == 'v' else gidx.reverse()
                 if op in ['add', 'sub', 'copy_rhs']:
                     dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, _addsub(op, dZ))[0]
                 else:  # mul, div, dot

@@ -277,7 +277,7 @@ class EdgeSoftmax(mx.autograd.Function):
         if not is_all(eids):
             gidx = gidx.edge_subgraph([eids], True).graph
         if norm_by == 'src':
-            gidx = _reverse(gidx)
+            gidx = gidx.reverse()
         self.gidx = gidx

     def forward(self, score):
python/dgl/backend/pytorch/sparse.py

 import torch as th
 from distutils.version import LooseVersion

 from ...base import is_all, ALL
-from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _reverse
+from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp

 if LooseVersion(th.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import custom_fwd, custom_bwd

@@ -27,6 +27,34 @@ else:
 __all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce']

+_inverse_format = {
+    'coo': 'coo',
+    'csr': 'csc',
+    'csc': 'csr'
+}
+
+
+def _reverse(gidx):
+    """Reverse the given graph index while retaining its formats.
+
+    Parameters
+    ----------
+    gidx: HeteroGraphIndex
+
+    Return
+    ------
+    HeteroGraphIndex
+    """
+    g_rev = gidx.reverse()
+    original_formats_dict = gidx.formats()
+    original_formats = original_formats_dict['created'] + \
+        original_formats_dict['not created']
+    g_rev = g_rev.formats(
+        [_inverse_format[fmt] for fmt in original_formats])
+    return g_rev
+
+
 def _reduce_grad(grad, shape):
     """Reduce gradient on the broadcast dimension

     If there is broadcast in forward pass, gradients need to be reduced on
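The inverse-format map in the restored helper reflects a basic fact about sparse layouts: reversing every edge transposes the adjacency matrix, and the transpose of a CSR matrix is exactly a CSC matrix over the same index arrays (and vice versa), while COO is unchanged up to swapping its row and column arrays. A quick illustration with SciPy, which is not used by this code and serves only to make the format flip visible:

    import numpy as np
    from scipy.sparse import csr_matrix

    # Adjacency of a tiny directed graph with edges 0->1 and 1->2.
    adj = csr_matrix(np.array([[0, 1, 0],
                               [0, 0, 1],
                               [0, 0, 0]]))
    print(adj.format)        # 'csr'

    # Transposing (i.e. reversing all edges) yields the same data viewed as CSC.
    rev = adj.transpose()
    print(rev.format)        # 'csc'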
python/dgl/backend/tensorflow/sparse.py

@@ -2,7 +2,7 @@ import tensorflow as tf
 import numpy as np

 from .tensor import tensor, copy_to, context, asnumpy, zerocopy_from_numpy
 from ...base import is_all, ALL
-from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _reverse
+from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp

 __all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce']

@@ -110,7 +110,7 @@ def gspmm_real(gidx, op, reduce_op, X, Y):
     def grad(dZ):
         dZ = tensor(dZ)
         if op != 'copy_rhs':
-            g_rev = _reverse(gidx)
+            g_rev = gidx.reverse()
             if reduce_op == 'sum':
                 if op in ['mul', 'div']:
                     dX = _gspmm(g_rev, 'mul', 'sum', dZ, _muldiv(op, Y))[0]

@@ -172,7 +172,7 @@ def gsddmm_real(gidx, op, X, Y, lhs_target, rhs_target):
     def grad(dZ):
         if op != 'copy_rhs':
             if lhs_target in ['u', 'v']:
-                _gidx = gidx if lhs_target == 'v' else _reverse(gidx)
+                _gidx = gidx if lhs_target == 'v' else gidx.reverse()
                 if op in ['add', 'sub', 'copy_lhs']:
                     dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0]
                 else:  # mul, div, dot

@@ -192,7 +192,7 @@ def gsddmm_real(gidx, op, X, Y, lhs_target, rhs_target):
             dX = tf.zeros_like(X)
         if op != 'copy_lhs':
             if rhs_target in ['u', 'v']:
-                _gidx = gidx if rhs_target == 'v' else _reverse(gidx)
+                _gidx = gidx if rhs_target == 'v' else gidx.reverse()
                 if op in ['add', 'sub', 'copy_rhs']:
                     dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, _addsub(op, dZ))[0]
                 else:  # mul, div, dot

@@ -233,7 +233,7 @@ def edge_softmax_real(gidx, score, eids=ALL, norm_by='dst'):
     if not is_all(eids):
         gidx = gidx.edge_subgraph([eids], True).graph
     if norm_by == 'src':
-        gidx = _reverse(gidx)
+        gidx = gidx.reverse()
     score_max = _gspmm(gidx, 'copy_rhs', 'max', None, score)[0]
     score = tf.math.exp(_gsddmm(gidx, 'sub', score, score_max, 'e', 'v'))
     score_sum = _gspmm(gidx, 'copy_rhs', 'sum', None, score)[0]
python/dgl/heterograph.py

@@ -5449,13 +5449,6 @@ class DGLHeteroGraph(object):
         >>> # Only allowed formats will be displayed in the status query
         >>> csr_g.formats()
         {'created': ['csr'], 'not created': []}
-
-        Notes
-        -----
-        DGL will create sparse formats (only constrained to the allowed formats, i.e.
-        created formats and not created formats) on-the-fly during the training of Graph
-        Neural Networks. Once a format was created, it would be cached and reused until
-        user changes the graph structure.
         """
         if formats is None:
             # Return the format information
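The Notes paragraph removed above described DGL's lazy handling of sparse formats: within the allowed set, a format is created on demand the first time an operation needs it and is then cached until the graph structure changes. A short sketch of how this surfaces through the public formats() query, assuming DGL >= 0.5 with the PyTorch backend; the exact 'created'/'not created' lists printed for the first graph may differ by version:

    import torch
    import dgl

    # A toy graph with edges 0->1 and 1->2; by default all formats are allowed.
    g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
    print(g.formats())        # e.g. {'created': ['coo'], 'not created': ['csr', 'csc']}

    # Restrict the graph to CSR only; the status query reflects the allowed set.
    csr_g = g.formats(['csr'])
    print(csr_g.formats())    # {'created': ['csr'], 'not created': []}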
python/dgl/sparse.py

@@ -7,8 +7,6 @@ from ._ffi.function import _init_api
 from .base import DGLError
 from . import backend as F

-__all__ = ['_gspmm', '_gsddmm', '_segment_reduce', '_bwd_segment_cmp',
-           '_reverse']

 def infer_broadcast_shape(op, shp1, shp2):
     r"""Check the shape validity, and infer the output shape given input shape and operator.

@@ -67,33 +65,6 @@ def to_dgl_nd_for_write(x):
     return nd.NULL['int64'] if x is None else F.zerocopy_to_dgl_ndarray_for_write(x)

-inverse_format = {
-    'coo': 'coo',
-    'csr': 'csc',
-    'csc': 'csr'
-}
-
-
-def _reverse(gidx):
-    """Reverse the given graph index while retaining its formats.
-
-    ``dgl.reverse`` would not keep graph format information by default.
-
-    Parameters
-    ----------
-    gidx: HeteroGraphIndex
-
-    Return
-    ------
-    HeteroGraphIndex
-    """
-    g_rev = gidx.reverse()
-    original_formats_dict = gidx.formats()
-    original_formats = original_formats_dict['created'] + \
-        original_formats_dict['not created']
-    g_rev = g_rev.formats(
-        [inverse_format[fmt] for fmt in original_formats])
-    return g_rev
-

 target_mapping = {
     'u': 0,
     'e': 1,