Commit 2efdaa5d (unverified)
Authored Jul 14, 2022 by Quan (Andy) Gan; committed by GitHub on Jul 14, 2022
[Bug] Revert clearing backward cache for retain_graph flag (#4249)
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>

Parent: 5c76e47f
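Why this matters: when the autograd graph is kept alive with retain_graph=True (or during double backward), PyTorch can invoke backward() again on the same ctx, so any state that backward() wipes from ctx is missing on the second call. The following is a minimal sketch of that failure mode; CachedMul and its scale argument are hypothetical illustrations, not DGL code (DGL's real backward_cache holds the graph index and op metadata unpacked in the hunks below).

import torch as th

class CachedMul(th.autograd.Function):
    """Toy op (hypothetical) that stashes non-tensor state on ``ctx`` the way
    DGL's sparse autograd wrappers stash ``ctx.backward_cache``."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.backward_cache = scale   # non-tensor state needed by backward()
        return x * scale

    @staticmethod
    def backward(ctx, dZ):
        scale = ctx.backward_cache
        # ctx.backward_cache = None   # <- the kind of line this commit deletes:
        # clearing the cache here breaks a second backward pass, because with
        # retain_graph=True autograd reuses this very ctx.
        return dZ * scale, None

x = th.randn(3, requires_grad=True)
y = CachedMul.apply(x, 2.0).sum()
y.backward(retain_graph=True)   # first pass succeeds either way
y.backward()                    # second pass needs backward_cache to still be set

Accordingly, every hunk below only deletes a ctx.backward_cache = None statement (plus, where present, the accompanying # See https://github.com/dmlc/dgl/pull/3386 comment), so the cached metadata survives repeated backward passes.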
Changes: 2 changed files with 0 additions and 17 deletions (+0 -17)

python/dgl/backend/pytorch/sparse.py   +0 -13
python/dgl/backend/pytorch/tensor.py   +0 -4
python/dgl/backend/pytorch/sparse.py
@@ -127,7 +127,6 @@ class GSpMM(th.autograd.Function):
     @custom_bwd
     def backward(ctx, dZ):
         gidx, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last = ctx.backward_cache
-        ctx.backward_cache = None
         X, Y, argX, argY = ctx.saved_tensors
         if op != 'copy_rhs' and ctx.needs_input_grad[3]:
             g_rev = gidx.reverse()
@@ -207,7 +206,6 @@ class GSpMM_hetero(th.autograd.Function):
     @custom_bwd
     def backward(ctx, *dZ):
         gidx, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last, X_len = ctx.backward_cache
-        ctx.backward_cache = None
         num_ntypes = gidx.number_of_ntypes()
         feats = ctx.saved_tensors[:-(4 * num_ntypes)]
         argX = ctx.saved_tensors[-(4 * num_ntypes):-(3 * num_ntypes)]
@@ -305,7 +303,6 @@ class GSDDMM(th.autograd.Function):
     @custom_bwd
     def backward(ctx, dZ):
         gidx, op, lhs_target, rhs_target, X_shape, Y_shape = ctx.backward_cache
-        ctx.backward_cache = None
         X, Y = ctx.saved_tensors
         if op != 'copy_rhs' and ctx.needs_input_grad[2]:
             if lhs_target in ['u', 'v']:
@@ -373,7 +370,6 @@ class GSDDMM_hetero(th.autograd.Function):
     # TODO(Israt): Implement the complete backward operator
     def backward(ctx, *dZ):
         gidx, op, lhs_target, rhs_target, X_shape, Y_shape, X_len = ctx.backward_cache
-        ctx.backward_cache = None
         feats = ctx.saved_tensors
         X, Y = feats[:X_len], feats[X_len:]
         if op != 'copy_rhs' and any([x is not None for x in X]):
@@ -484,8 +480,6 @@ class EdgeSoftmax(th.autograd.Function):
                 return grad_score.data
         """
         gidx = ctx.backward_cache
-        # See https://github.com/dmlc/dgl/pull/3386
-        ctx.backward_cache = None
         out, = ctx.saved_tensors
         sds = out * grad_out
         #Note: Now _edge_softmax_backward op only supports CPU
@@ -554,8 +548,6 @@ class EdgeSoftmax_hetero(th.autograd.Function):
                 return grad_score.data
         """
         gidx = ctx.backward_cache
-        # See https://github.com/dmlc/dgl/pull/3386
-        ctx.backward_cache = None
         u_len = gidx.number_of_ntypes()
         e_len = gidx.number_of_etypes()
         lhs = [None] * u_len
@@ -582,8 +574,6 @@ class SegmentReduce(th.autograd.Function):
     @custom_bwd
     def backward(ctx, dy):
         op = ctx.backward_cache
-        # See https://github.com/dmlc/dgl/pull/3386
-        ctx.backward_cache = None
         arg, offsets = ctx.saved_tensors
         m = offsets[-1].item()
         if op == 'sum':
@@ -630,7 +620,6 @@ class CSRMM(th.autograd.Function):
     def backward(ctx, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights):
         # Only the last argument is meaningful.
         gidxA, gidxB, gidxC = ctx.backward_cache
-        ctx.backward_cache = None
         A_weights, B_weights = ctx.saved_tensors
         dgidxA, dA_weights = csrmm(
             gidxC, dC_weights, gidxB.reverse(), B_weights, gidxA.number_of_ntypes())
@@ -657,7 +646,6 @@ class CSRSum(th.autograd.Function):
     def backward(ctx, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights):
         # Only the last argument is meaningful.
         gidxs, gidxC = ctx.backward_cache
-        ctx.backward_cache = None
         return (None,) + tuple(
             csrmask(gidxC, dC_weights, gidx) for gidx in gidxs)
@@ -670,7 +658,6 @@ class CSRMask(th.autograd.Function):
     @staticmethod
     def backward(ctx, dB_weights):
         gidxA, gidxB = ctx.backward_cache
-        ctx.backward_cache = None
         return None, csrmask(gidxB, dB_weights, gidxA), None
python/dgl/backend/pytorch/tensor.py
@@ -418,8 +418,6 @@ class BinaryReduce(th.autograd.Function):
     def backward(ctx, grad_out):
         reducer, binary_op, graph, lhs, rhs, lhs_map, rhs_map, out_map, \
             feat_shape, degs = ctx.backward_cache
-        # See https://github.com/dmlc/dgl/pull/3386
-        ctx.backward_cache = None
         lhs_data, rhs_data, out_data = ctx.saved_tensors
         lhs_data_nd = zerocopy_to_dgl_ndarray(lhs_data)
         rhs_data_nd = zerocopy_to_dgl_ndarray(rhs_data)
@@ -497,8 +495,6 @@ class CopyReduce(th.autograd.Function):
     @staticmethod
     def backward(ctx, grad_out):
         reducer, graph, target, in_map, out_map, degs = ctx.backward_cache
-        # See https://github.com/dmlc/dgl/pull/3386
-        ctx.backward_cache = None
         in_data, out_data = ctx.saved_tensors
         in_data_nd = zerocopy_to_dgl_ndarray(in_data)
         out_data_nd = zerocopy_to_dgl_ndarray(out_data)