OpenDAS / dgl

Commit 76af2a2e (unverified), authored Aug 17, 2021 by Zihao Ye, committed by GitHub on Aug 17, 2021

[perf] Remove activation cache if not required. (#3258)

* upd
* fix
* upd

Parent: ac01e880
Showing 1 changed file with 130 additions and 55 deletions

python/dgl/backend/pytorch/sparse.py  (+130 / -55)
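The diff below makes GSpMM and GSDDMM save only the tensors their backward pass will actually read, instead of unconditionally stashing X, Y, argX and argY. As a minimal, self-contained sketch of that pattern (toy code for illustration only, not part of the diff; the ScaleByY name and its caching rule are invented):

import torch as th

class ScaleByY(th.autograd.Function):
    """Toy elementwise x * y that caches y only when x's gradient needs it."""

    @staticmethod
    def forward(ctx, x, y):
        # Rule: dx = dz * y, so y is only worth caching when x requires grad.
        ctx.need_y = x.requires_grad
        if ctx.need_y:
            ctx.save_for_backward(y)
        else:
            ctx.save_for_backward()   # nothing is kept alive for backward
        return x * y

    @staticmethod
    def backward(ctx, dz):
        if ctx.need_y:
            y, = ctx.saved_tensors
            return dz * y, None       # y is treated as a constant in this toy
        return None, None

x = th.randn(4, requires_grad=True)
y = th.randn(4)
ScaleByY.apply(x, y).sum().backward()  # x.grad equals y; y was cached only because it was needed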
@@ -66,21 +66,54 @@ def _need_reduce_last_dim(ufeat, efeat):
     """Indicates whether to reduce the last dimension on edges
     in the backward pass of spmm,
     if so, use dot instead of mul."""
     if ufeat is None or efeat is None:
         return False
     ushp = ufeat.shape
     eshp = efeat.shape
     return ushp[1:-1] == eshp[1:-1] and eshp[-1] == 1 and ushp[-1] > 1
 
 
 def _muldiv(op, x):
     return 1. / x if op == 'div' else x
 
 
-def _expand(x, shape):
-    return x.expand(-1, *shape)
-
-
 def _addsub(op, x):
     return -x if op == 'sub' else x
 
 
+def spmm_cache_X(binary_op, reduce_op, req_grad_X, req_grad_Y):
+    """Rules to identify whether to cache X in SpMM forward stage."""
+    if binary_op != 'copy_lhs' and req_grad_Y:
+        if reduce_op == 'sum':
+            return True
+        else:
+            if binary_op == 'mul':
+                return True
+    return False
+
+
+def _expand(x, shape):
+    return x.expand(-1, *shape)
+
+
+def spmm_cache_Y(binary_op, reduce_op, req_grad_X, req_grad_Y):
+    """Rules to identify whether to cache Y in SpMM forward stage."""
+    if binary_op != 'copy_rhs' and req_grad_X:
+        if reduce_op == 'sum':
+            if binary_op in ['mul', 'add']:
+                return True
+        else:
+            if binary_op == 'mul':
+                return True
+    return False
+
+
+def spmm_cache_argX(binary_op, reduce_op, req_grad_X, req_grad_Y):
+    """Rules to identify whether to cache argX in SpMM forward stage."""
+    if req_grad_X:
+        if reduce_op in ['min', 'max']:
+            return True
+    return False
+
+
+def spmm_cache_argY(binary_op, reduce_op, req_grad_X, req_grad_Y):
+    """Rules to identify whether to cache argY in SpMM forward stage."""
+    if req_grad_Y:
+        if reduce_op in ['min', 'max']:
+            return True
+    return False
 
 
 class GSpMM(th.autograd.Function):
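For reference, the new spmm_cache_* rules can be exercised directly; a small check (assuming the module can be imported as dgl.backend.pytorch.sparse, matching the file path above, with the PyTorch backend available):

from dgl.backend.pytorch import sparse as S

# copy_lhs with sum-reduce needs no cached tensor at all:
assert not S.spmm_cache_X('copy_lhs', 'sum', True, True)
assert not S.spmm_cache_Y('copy_lhs', 'sum', True, True)

# mul with sum-reduce caches X only if Y needs a gradient, and vice versa:
assert S.spmm_cache_X('mul', 'sum', False, True)
assert not S.spmm_cache_X('mul', 'sum', True, False)
assert S.spmm_cache_Y('mul', 'sum', True, False)

# argmax/argmin indices are kept only for min/max reductions:
assert S.spmm_cache_argX('mul', 'max', True, False)
assert not S.spmm_cache_argX('mul', 'sum', True, False)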
@@ -88,58 +121,69 @@ class GSpMM(th.autograd.Function):
     @staticmethod
     @custom_fwd(cast_inputs=th.float16)
     def forward(ctx, gidx, op, reduce_op, X, Y):
         out, (argX, argY) = _gspmm(gidx, op, reduce_op, X, Y)
-        ctx.backward_cache = gidx, op, reduce_op
+        reduce_last = _need_reduce_last_dim(X, Y)
+        X_shape = X.shape if X is not None else None
+        Y_shape = Y.shape if Y is not None else None
+        dtype = X.dtype if X is not None else Y.dtype
+        device = X.device if X is not None else Y.device
+        ctx.backward_cache = gidx, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last
+        req_grad_X = X.requires_grad if X is not None else False
+        req_grad_Y = Y.requires_grad if Y is not None else False
+        if not spmm_cache_X(op, reduce_op, req_grad_X, req_grad_Y):
+            X = None
+        if not spmm_cache_Y(op, reduce_op, req_grad_X, req_grad_Y):
+            Y = None
+        if not spmm_cache_argX(op, reduce_op, req_grad_X, req_grad_Y):
+            argX = None
+        if not spmm_cache_argY(op, reduce_op, req_grad_X, req_grad_Y):
+            argY = None
         ctx.save_for_backward(X, Y, argX, argY)
         return out
 
     @staticmethod
     @custom_bwd
     def backward(ctx, dZ):
-        gidx, op, reduce_op = ctx.backward_cache
+        gidx, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last = ctx.backward_cache
         X, Y, argX, argY = ctx.saved_tensors
         if op != 'copy_rhs' and ctx.needs_input_grad[3]:
             g_rev = gidx.reverse()
             if reduce_op == 'sum':
-                if op in ['mul', 'div']:
-                    dX = gspmm(g_rev, 'mul', 'sum', dZ, _muldiv(op, Y))
-                elif op in ['add', 'sub']:
+                if op == 'mul':
+                    dX = gspmm(g_rev, 'mul', 'sum', dZ, Y)
+                elif op == 'add':
                     dX = gspmm(g_rev, 'copy_lhs', 'sum', dZ, Y)
                 elif op == 'copy_lhs':
                     dX = gspmm(g_rev, 'copy_lhs', 'sum', dZ, None)
             else:  # max/min
-                dX = th.zeros((X.shape[0],) + dZ.shape[1:],
-                              dtype=X.dtype, device=X.device)
-                if op in ['mul', 'div']:
-                    grad = _muldiv(op, _expand(Y, dZ.shape[1:]).gather(0, argY.long())) * dZ
+                dX = th.zeros((X_shape[0],) + dZ.shape[1:],
+                              dtype=dtype, device=device)
+                if op == 'mul':
+                    grad = _expand(Y, dZ.shape[1:]).gather(0, argY.long()) * dZ
                     dX.scatter_add_(0, argX.long(), grad)
-                elif op in ['add', 'sub', 'copy_lhs']:
+                elif op in ['add', 'copy_lhs']:
                     dX.scatter_add_(0, argX.long(), dZ)
-            dX = _reduce_grad(dX, X.shape)
+            dX = _reduce_grad(dX, X_shape)
         else:  # X has not gradient
             dX = None
         if op != 'copy_lhs' and ctx.needs_input_grad[4]:
             if reduce_op == 'sum':
-                if op == 'mul' and _need_reduce_last_dim(X, Y):
+                if op == 'mul' and reduce_last:
                     dY = gsddmm(gidx, 'dot', X, dZ)
-                elif op in ['mul', 'div']:
+                elif op == 'mul':
                     dY = gsddmm(gidx, 'mul', X, dZ)
-                    if op == 'div':
-                        dY = -dY / (Y ** 2)
-                elif op in ['add', 'sub', 'copy_rhs']:
-                    dY = gsddmm(gidx, 'copy_rhs', X, _addsub(op, dZ))
+                elif op in ['add', 'copy_rhs']:
+                    dY = gsddmm(gidx, 'copy_rhs', X, dZ)
             else:  # max/min
-                dY = th.zeros((Y.shape[0],) + dZ.shape[1:],
-                              dtype=Y.dtype, device=Y.device)
-                if op in ['mul', 'div']:
+                dY = th.zeros((Y_shape[0],) + dZ.shape[1:],
+                              dtype=dtype, device=device)
+                if op == 'mul':
                     grad = _expand(X, dZ.shape[1:]).gather(0, argX.long()) * dZ
                     dY.scatter_add_(0, argY.long(), grad)
-                    if op == 'div':
-                        dY = -dY / (Y ** 2)
-                elif op in ['add', 'sub', 'copy_rhs']:
-                    dY.scatter_add_(0, argY.long(), _addsub(op, dZ))
-            dY = _reduce_grad(dY, Y.shape)
+                elif op in ['add', 'copy_rhs']:
+                    dY.scatter_add_(0, argY.long(), dZ)
+            dY = _reduce_grad(dY, Y_shape)
         else:  # Y has no gradient
             dY = None
         return None, None, None, dX, dY
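In the max/min branch above, the gradient buffer is rebuilt from the cached shape, dtype and device rather than from the (now discarded) input tensor, and dZ is routed back through the cached arg indices with scatter_add_. A standalone toy illustration of that scatter step (shapes and values invented, plain PyTorch only):

import torch as th

dZ = th.tensor([[1.], [2.], [3.], [4.]])   # gradient of the max-reduced output (4 destination rows)
argX = th.tensor([[2], [0], [5], [1]])     # source row that produced each maximum
X_shape = (6, 1)                           # cached shape of the discarded input X

dX = th.zeros((X_shape[0],) + dZ.shape[1:], dtype=dZ.dtype)
dX.scatter_add_(0, argX.long(), dZ)        # each dZ row flows back to its winning source row
print(dX.squeeze(-1))                      # tensor([2., 4., 1., 0., 0., 3.])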
@@ -178,7 +222,7 @@ class GSpMM_hetero(th.autograd.Function):
             # TODO(Israt): implement other combinations of message and reduce functions
             if reduce_op == 'sum':
                 if op in ['copy_rhs']:
-                    tmp_Z = tuple([_addsub(op, dZ[i]) if dZ[i] is not None else None
+                    tmp_Z = tuple([dZ[i] if dZ[i] is not None else None
                                    for i in range(len(dZ))])
                     tmp = tuple(X + tmp_Z)
                     dY = gsddmm_hetero(g, 'copy_rhs', 'u', 'v', *tmp)
@@ -188,62 +232,81 @@ class GSpMM_hetero(th.autograd.Function):
             dY = tuple([None] * len(Y))
         return (None, None, None) + dX + dY
 
 
+def sddmm_cache_X(op, req_grad_X, req_grad_Y):
+    """Rules to identify whether to cache X in SDDMM forward stage."""
+    if op in ['mul', 'dot'] and req_grad_Y:
+        return True
+    return False
+
+
+def sddmm_cache_Y(op, req_grad_X, req_grad_Y):
+    """Rules to identify whether to cache Y in SDDMM forward stage."""
+    if op in ['mul', 'dot'] and req_grad_X:
+        return True
+    return False
+
+
 class GSDDMM(th.autograd.Function):
     @staticmethod
     @custom_fwd(cast_inputs=th.float16)
     def forward(ctx, gidx, op, X, Y, lhs_target, rhs_target):
         out = _gsddmm(gidx, op, X, Y, lhs_target, rhs_target)
-        ctx.backward_cache = gidx, op, lhs_target, rhs_target
+        X_shape = X.shape if X is not None else None
+        Y_shape = Y.shape if Y is not None else None
+        ctx.backward_cache = gidx, op, lhs_target, rhs_target, X_shape, Y_shape
+        req_grad_X = X.requires_grad if X is not None else False
+        req_grad_Y = Y.requires_grad if Y is not None else False
+        if not sddmm_cache_X(op, req_grad_X, req_grad_Y):
+            X = None
+        if not sddmm_cache_Y(op, req_grad_X, req_grad_Y):
+            Y = None
         ctx.save_for_backward(X, Y)
         return out
 
     @staticmethod
     @custom_bwd
     def backward(ctx, dZ):
-        gidx, op, lhs_target, rhs_target = ctx.backward_cache
+        gidx, op, lhs_target, rhs_target, X_shape, Y_shape = ctx.backward_cache
         X, Y = ctx.saved_tensors
         if op != 'copy_rhs' and ctx.needs_input_grad[2]:
             if lhs_target in ['u', 'v']:
                 _gidx = gidx if lhs_target == 'v' else gidx.reverse()
-                if op in ['add', 'sub', 'copy_lhs']:
+                if op in ['add', 'copy_lhs']:
                     dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)
-                else:  # mul, div, dot
+                else:  # mul, dot
                     if rhs_target == lhs_target:
-                        dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ) * _muldiv(op, Y)
+                        dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ) * Y
                     elif rhs_target == 'e':
-                        dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * _muldiv(op, Y))
+                        dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * Y)
                     else:  # rhs_target = !lhs_target
-                        dX = gspmm(_gidx, 'mul', 'sum', _muldiv(op, Y), dZ)
+                        dX = gspmm(_gidx, 'mul', 'sum', Y, dZ)
             else:  # lhs_target == 'e'
-                if op in ['add', 'sub', 'copy_lhs']:
+                if op in ['add', 'copy_lhs']:
                     dX = dZ
-                else:  # mul, div, dot
-                    dX = gsddmm(gidx, 'mul', dZ, _muldiv(op, Y), 'e', rhs_target)
-            dX = _reduce_grad(dX, X.shape)
+                else:  # mul, dot
+                    dX = gsddmm(gidx, 'mul', dZ, Y, 'e', rhs_target)
+            dX = _reduce_grad(dX, X_shape)
         else:
             dX = None
         if op != 'copy_lhs' and ctx.needs_input_grad[3]:
             if rhs_target in ['u', 'v']:
                 _gidx = gidx if rhs_target == 'v' else gidx.reverse()
-                if op in ['add', 'sub', 'copy_rhs']:
-                    dY = gspmm(_gidx, 'copy_rhs', 'sum', None, _addsub(op, dZ))
-                else:  # mul, div, dot
+                if op in ['add', 'copy_rhs']:
+                    dY = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)
+                else:  # mul, dot
                     if lhs_target == rhs_target:
                         dY = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ) * X
                     elif lhs_target == 'e':
                         dY = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * X)
                     else:  # rhs_target = !lhs_target
                         dY = gspmm(_gidx, 'mul', 'sum', X, dZ)
-                    if op == 'div':
-                        dY = -dY / (Y ** 2)
             else:
-                if op in ['add', 'sub', 'copy_rhs']:
-                    dY = _addsub(op, dZ)
-                else:  # mul, div, dot
+                if op in ['add', 'copy_rhs']:
+                    dY = dZ
+                else:  # mul, dot
                     dY = gsddmm(gidx, 'mul', dZ, X, 'e', lhs_target)
-                    if op == 'div':
-                        dY = -dY / (Y ** 2)
-            dY = _reduce_grad(dY, Y.shape)
+            dY = _reduce_grad(dY, Y_shape)
         else:
             dY = None
         return None, None, dX, dY, None, None
@@ -422,9 +485,21 @@ class CSRMask(th.autograd.Function):
 def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
+    if op == 'sub':
+        op = 'add'
+        rhs_data = -rhs_data
+    if op == 'div':
+        op = 'mul'
+        rhs_data = 1. / rhs_data
     return GSpMM.apply(gidx, op, reduce_op, lhs_data, rhs_data)
 
 
 def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
+    if op == 'sub':
+        op = 'add'
+        rhs_data = -rhs_data
+    if op == 'div':
+        op = 'mul'
+        rhs_data = 1. / rhs_data
     return GSDDMM.apply(gidx, op, lhs_data, rhs_data, lhs_target, rhs_target)
 
 
 def gspmm_hetero(g, op, reduce_op, *lhs_and_rhs_tuple):
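The wrappers above rewrite 'sub' and 'div' before GSpMM/GSDDMM ever run, which is why the backward code in the earlier hunks could drop its 'sub'/'div' branches. The underlying identities, checked on plain dense tensors (illustration only, no graph involved):

import torch as th

lhs = th.randn(5, 3)
rhs = th.randn(5, 3).abs() + 0.1                  # keep the divisor away from zero

assert th.allclose(lhs - rhs, lhs + (-rhs))       # 'sub' -> 'add' with negated rhs
assert th.allclose(lhs / rhs, lhs * (1. / rhs))   # 'div' -> 'mul' with reciprocal rhs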