Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
fb3c0709
Unverified
Commit
fb3c0709
authored
Jan 28, 2021
by
Zihao Ye
Committed by
GitHub
Jan 28, 2021
Browse files
Revert part of #2563 (#2584)
parent
878acdb0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
32 deletions
+4
-32
python/dgl/backend/pytorch/sparse.py
python/dgl/backend/pytorch/sparse.py
+4
-32
No files found.
python/dgl/backend/pytorch/sparse.py
View file @
fb3c0709
...
@@ -27,34 +27,6 @@ else:
...
@@ -27,34 +27,6 @@ else:
__all__
=
[
'gspmm'
,
'gsddmm'
,
'edge_softmax'
,
'segment_reduce'
]
__all__
=
[
'gspmm'
,
'gsddmm'
,
'edge_softmax'
,
'segment_reduce'
]
_inverse_format
=
{
'coo'
:
'coo'
,
'csr'
:
'csc'
,
'csc'
:
'csr'
}
def
_reverse
(
gidx
):
"""Reverse the given graph index while retaining its formats.
Parameters
----------
gidx: HeteroGraphIndex
Return
------
HeteroGraphIndex
"""
g_rev
=
gidx
.
reverse
()
original_formats_dict
=
gidx
.
formats
()
original_formats
=
original_formats_dict
[
'created'
]
+
\
original_formats_dict
[
'not created'
]
g_rev
=
g_rev
.
formats
([
_inverse_format
[
fmt
]
for
fmt
in
original_formats
])
return
g_rev
def
_reduce_grad
(
grad
,
shape
):
def
_reduce_grad
(
grad
,
shape
):
"""Reduce gradient on the broadcast dimension
"""Reduce gradient on the broadcast dimension
If there is broadcast in forward pass, gradients need to be reduced on
If there is broadcast in forward pass, gradients need to be reduced on
...
@@ -123,7 +95,7 @@ class GSpMM(th.autograd.Function):
...
@@ -123,7 +95,7 @@ class GSpMM(th.autograd.Function):
gidx
,
op
,
reduce_op
=
ctx
.
backward_cache
gidx
,
op
,
reduce_op
=
ctx
.
backward_cache
X
,
Y
,
argX
,
argY
=
ctx
.
saved_tensors
X
,
Y
,
argX
,
argY
=
ctx
.
saved_tensors
if
op
!=
'copy_rhs'
and
ctx
.
needs_input_grad
[
3
]:
if
op
!=
'copy_rhs'
and
ctx
.
needs_input_grad
[
3
]:
g_rev
=
_
reverse
(
gidx
)
g_rev
=
gidx
.
reverse
()
if
reduce_op
==
'sum'
:
if
reduce_op
==
'sum'
:
if
op
in
[
'mul'
,
'div'
]:
if
op
in
[
'mul'
,
'div'
]:
dX
=
gspmm
(
g_rev
,
'mul'
,
'sum'
,
dZ
,
_muldiv
(
op
,
Y
))
dX
=
gspmm
(
g_rev
,
'mul'
,
'sum'
,
dZ
,
_muldiv
(
op
,
Y
))
...
@@ -186,7 +158,7 @@ class GSDDMM(th.autograd.Function):
...
@@ -186,7 +158,7 @@ class GSDDMM(th.autograd.Function):
X
,
Y
=
ctx
.
saved_tensors
X
,
Y
=
ctx
.
saved_tensors
if
op
!=
'copy_rhs'
and
ctx
.
needs_input_grad
[
2
]:
if
op
!=
'copy_rhs'
and
ctx
.
needs_input_grad
[
2
]:
if
lhs_target
in
[
'u'
,
'v'
]:
if
lhs_target
in
[
'u'
,
'v'
]:
_gidx
=
gidx
if
lhs_target
==
'v'
else
_
reverse
(
gidx
)
_gidx
=
gidx
if
lhs_target
==
'v'
else
gidx
.
reverse
()
if
op
in
[
'add'
,
'sub'
,
'copy_lhs'
]:
if
op
in
[
'add'
,
'sub'
,
'copy_lhs'
]:
dX
=
gspmm
(
_gidx
,
'copy_rhs'
,
'sum'
,
None
,
dZ
)
dX
=
gspmm
(
_gidx
,
'copy_rhs'
,
'sum'
,
None
,
dZ
)
else
:
# mul, div, dot
else
:
# mul, div, dot
...
@@ -206,7 +178,7 @@ class GSDDMM(th.autograd.Function):
...
@@ -206,7 +178,7 @@ class GSDDMM(th.autograd.Function):
dX
=
None
dX
=
None
if
op
!=
'copy_lhs'
and
ctx
.
needs_input_grad
[
3
]:
if
op
!=
'copy_lhs'
and
ctx
.
needs_input_grad
[
3
]:
if
rhs_target
in
[
'u'
,
'v'
]:
if
rhs_target
in
[
'u'
,
'v'
]:
_gidx
=
gidx
if
rhs_target
==
'v'
else
_
reverse
(
gidx
)
_gidx
=
gidx
if
rhs_target
==
'v'
else
gidx
.
reverse
()
if
op
in
[
'add'
,
'sub'
,
'copy_rhs'
]:
if
op
in
[
'add'
,
'sub'
,
'copy_rhs'
]:
dY
=
gspmm
(
_gidx
,
'copy_rhs'
,
'sum'
,
None
,
_addsub
(
op
,
dZ
))
dY
=
gspmm
(
_gidx
,
'copy_rhs'
,
'sum'
,
None
,
_addsub
(
op
,
dZ
))
else
:
# mul, div, dot
else
:
# mul, div, dot
...
@@ -253,7 +225,7 @@ class EdgeSoftmax(th.autograd.Function):
...
@@ -253,7 +225,7 @@ class EdgeSoftmax(th.autograd.Function):
if
not
is_all
(
eids
):
if
not
is_all
(
eids
):
gidx
=
gidx
.
edge_subgraph
([
eids
],
True
).
graph
gidx
=
gidx
.
edge_subgraph
([
eids
],
True
).
graph
if
norm_by
==
'src'
:
if
norm_by
==
'src'
:
gidx
=
_
reverse
(
gidx
)
gidx
=
gidx
.
reverse
()
score_max
=
_gspmm
(
gidx
,
'copy_rhs'
,
'max'
,
None
,
score
)[
0
]
score_max
=
_gspmm
(
gidx
,
'copy_rhs'
,
'max'
,
None
,
score
)[
0
]
score
=
th
.
exp
(
_gsddmm
(
gidx
,
'sub'
,
score
,
score_max
,
'e'
,
'v'
))
score
=
th
.
exp
(
_gsddmm
(
gidx
,
'sub'
,
score
,
score_max
,
'e'
,
'v'
))
score_sum
=
_gspmm
(
gidx
,
'copy_rhs'
,
'sum'
,
None
,
score
)[
0
]
score_sum
=
_gspmm
(
gidx
,
'copy_rhs'
,
'sum'
,
None
,
score
)[
0
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment