Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
fb3c0709
Unverified
Commit
fb3c0709
authored
Jan 28, 2021
by
Zihao Ye
Committed by
GitHub
Jan 28, 2021
Browse files
Revert part of #2563 (#2584)
parent
878acdb0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
32 deletions
+4
-32
python/dgl/backend/pytorch/sparse.py
python/dgl/backend/pytorch/sparse.py
+4
-32
No files found.
python/dgl/backend/pytorch/sparse.py
View file @
fb3c0709
...
...
@@ -27,34 +27,6 @@ else:
__all__
=
[
'gspmm'
,
'gsddmm'
,
'edge_softmax'
,
'segment_reduce'
]
_inverse_format
=
{
'coo'
:
'coo'
,
'csr'
:
'csc'
,
'csc'
:
'csr'
}
def
_reverse
(
gidx
):
"""Reverse the given graph index while retaining its formats.
Parameters
----------
gidx: HeteroGraphIndex
Return
------
HeteroGraphIndex
"""
g_rev
=
gidx
.
reverse
()
original_formats_dict
=
gidx
.
formats
()
original_formats
=
original_formats_dict
[
'created'
]
+
\
original_formats_dict
[
'not created'
]
g_rev
=
g_rev
.
formats
([
_inverse_format
[
fmt
]
for
fmt
in
original_formats
])
return
g_rev
def
_reduce_grad
(
grad
,
shape
):
"""Reduce gradient on the broadcast dimension
If there is broadcast in forward pass, gradients need to be reduced on
...
...
@@ -123,7 +95,7 @@ class GSpMM(th.autograd.Function):
gidx
,
op
,
reduce_op
=
ctx
.
backward_cache
X
,
Y
,
argX
,
argY
=
ctx
.
saved_tensors
if
op
!=
'copy_rhs'
and
ctx
.
needs_input_grad
[
3
]:
g_rev
=
_
reverse
(
gidx
)
g_rev
=
gidx
.
reverse
()
if
reduce_op
==
'sum'
:
if
op
in
[
'mul'
,
'div'
]:
dX
=
gspmm
(
g_rev
,
'mul'
,
'sum'
,
dZ
,
_muldiv
(
op
,
Y
))
...
...
@@ -186,7 +158,7 @@ class GSDDMM(th.autograd.Function):
X
,
Y
=
ctx
.
saved_tensors
if
op
!=
'copy_rhs'
and
ctx
.
needs_input_grad
[
2
]:
if
lhs_target
in
[
'u'
,
'v'
]:
_gidx
=
gidx
if
lhs_target
==
'v'
else
_
reverse
(
gidx
)
_gidx
=
gidx
if
lhs_target
==
'v'
else
gidx
.
reverse
()
if
op
in
[
'add'
,
'sub'
,
'copy_lhs'
]:
dX
=
gspmm
(
_gidx
,
'copy_rhs'
,
'sum'
,
None
,
dZ
)
else
:
# mul, div, dot
...
...
@@ -206,7 +178,7 @@ class GSDDMM(th.autograd.Function):
dX
=
None
if
op
!=
'copy_lhs'
and
ctx
.
needs_input_grad
[
3
]:
if
rhs_target
in
[
'u'
,
'v'
]:
_gidx
=
gidx
if
rhs_target
==
'v'
else
_
reverse
(
gidx
)
_gidx
=
gidx
if
rhs_target
==
'v'
else
gidx
.
reverse
()
if
op
in
[
'add'
,
'sub'
,
'copy_rhs'
]:
dY
=
gspmm
(
_gidx
,
'copy_rhs'
,
'sum'
,
None
,
_addsub
(
op
,
dZ
))
else
:
# mul, div, dot
...
...
@@ -253,7 +225,7 @@ class EdgeSoftmax(th.autograd.Function):
if
not
is_all
(
eids
):
gidx
=
gidx
.
edge_subgraph
([
eids
],
True
).
graph
if
norm_by
==
'src'
:
gidx
=
_
reverse
(
gidx
)
gidx
=
gidx
.
reverse
()
score_max
=
_gspmm
(
gidx
,
'copy_rhs'
,
'max'
,
None
,
score
)[
0
]
score
=
th
.
exp
(
_gsddmm
(
gidx
,
'sub'
,
score
,
score_max
,
'e'
,
'v'
))
score_sum
=
_gspmm
(
gidx
,
'copy_rhs'
,
'sum'
,
None
,
score
)[
0
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment