OpenDAS / dgl · Commits · 76af2a2e

Commit 76af2a2e (unverified)
Authored Aug 17, 2021 by Zihao Ye; committed via GitHub on Aug 17, 2021
Parent: ac01e880

[perf] Remove activation cache if not required. (#3258)

* upd
* fix
* upd

Showing 1 changed file with 130 additions and 55 deletions:
python/dgl/backend/pytorch/sparse.py (+130, -55)
@@ -66,21 +66,54 @@ def _need_reduce_last_dim(ufeat, efeat):
     """Indicates whether to reduce the last dimension on edges
     in the backward pass of spmm,
     if so, use dot instead of mul."""
     if ufeat is None or efeat is None:
         return False
     ushp = ufeat.shape
     eshp = efeat.shape
     return ushp[1:-1] == eshp[1:-1] and eshp[-1] == 1 and ushp[-1] > 1
 
 
-def _muldiv(op, x):
-    return 1. / x if op == 'div' else x
-
-
-def _addsub(op, x):
-    return -x if op == 'sub' else x
-
-
 def _expand(x, shape):
     return x.expand(-1, *shape)
 
 
+def spmm_cache_X(binary_op, reduce_op, req_grad_X, req_grad_Y):
+    """Rules to identify whether to cache X in SpMM forward stage."""
+    if binary_op != 'copy_lhs' and req_grad_Y:
+        if reduce_op == 'sum':
+            return True
+        else:
+            if binary_op == 'mul':
+                return True
+    return False
+
+
+def spmm_cache_Y(binary_op, reduce_op, req_grad_X, req_grad_Y):
+    """Rules to identify whether to cache Y in SpMM forward stage."""
+    if binary_op != 'copy_rhs' and req_grad_X:
+        if reduce_op == 'sum':
+            if binary_op in ['mul', 'add']:
+                return True
+        else:
+            if binary_op == 'mul':
+                return True
+    return False
+
+
+def spmm_cache_argX(binary_op, reduce_op, req_grad_X, req_grad_Y):
+    """Rules to identify whether to cache argX in SpMM forward stage."""
+    if req_grad_X:
+        if reduce_op in ['min', 'max']:
+            return True
+    return False
+
+
+def spmm_cache_argY(binary_op, reduce_op, req_grad_X, req_grad_Y):
+    """Rules to identify whether to cache argY in SpMM forward stage."""
+    if req_grad_Y:
+        if reduce_op in ['min', 'max']:
+            return True
+    return False
+
+
 class GSpMM(th.autograd.Function):
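Note (not part of the commit): the four spmm_cache_* predicates above implement the commit's core idea, keeping an operand or arg-index tensor alive only when some gradient formula in backward will actually read it. A minimal sanity sketch of how they behave, assuming the definitions in this hunk:

# Illustrative only; uses the spmm_cache_* functions defined above.
# 'copy_lhs' never reads Y in any gradient, so nothing extra is cached:
assert spmm_cache_X('copy_lhs', 'sum', True, True) is False
# For 'mul' with sum-reduce, dX reads Y and dY reads X, so an operand is
# cached exactly when its counterpart requires a gradient:
assert spmm_cache_X('mul', 'sum', False, True) is True
assert spmm_cache_Y('mul', 'sum', True, False) is True
# argX/argY (argmin/argmax indices) are only needed for min/max reduce:
assert spmm_cache_argX('add', 'max', True, False) is True
assert spmm_cache_argX('add', 'sum', True, False) is False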
@@ -88,58 +121,69 @@ class GSpMM(th.autograd.Function):
     @custom_fwd(cast_inputs=th.float16)
     def forward(ctx, gidx, op, reduce_op, X, Y):
         out, (argX, argY) = _gspmm(gidx, op, reduce_op, X, Y)
-        ctx.backward_cache = gidx, op, reduce_op
+        reduce_last = _need_reduce_last_dim(X, Y)
+        X_shape = X.shape if X is not None else None
+        Y_shape = Y.shape if Y is not None else None
+        dtype = X.dtype if X is not None else Y.dtype
+        device = X.device if X is not None else Y.device
+        ctx.backward_cache = gidx, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last
+        req_grad_X = X.requires_grad if X is not None else False
+        req_grad_Y = Y.requires_grad if Y is not None else False
+        if not spmm_cache_X(op, reduce_op, req_grad_X, req_grad_Y):
+            X = None
+        if not spmm_cache_Y(op, reduce_op, req_grad_X, req_grad_Y):
+            Y = None
+        if not spmm_cache_argX(op, reduce_op, req_grad_X, req_grad_Y):
+            argX = None
+        if not spmm_cache_argY(op, reduce_op, req_grad_X, req_grad_Y):
+            argY = None
         ctx.save_for_backward(X, Y, argX, argY)
         return out
 
     @staticmethod
     @custom_bwd
     def backward(ctx, dZ):
-        gidx, op, reduce_op = ctx.backward_cache
+        gidx, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last = ctx.backward_cache
         X, Y, argX, argY = ctx.saved_tensors
         if op != 'copy_rhs' and ctx.needs_input_grad[3]:
             g_rev = gidx.reverse()
             if reduce_op == 'sum':
-                if op in ['mul', 'div']:
-                    dX = gspmm(g_rev, 'mul', 'sum', dZ, _muldiv(op, Y))
-                elif op in ['add', 'sub']:
+                if op == 'mul':
+                    dX = gspmm(g_rev, 'mul', 'sum', dZ, Y)
+                elif op == 'add':
                     dX = gspmm(g_rev, 'copy_lhs', 'sum', dZ, Y)
                 elif op == 'copy_lhs':
                     dX = gspmm(g_rev, 'copy_lhs', 'sum', dZ, None)
             else:  # max/min
-                dX = th.zeros((X.shape[0],) + dZ.shape[1:],
-                              dtype=X.dtype, device=X.device)
-                if op in ['mul', 'div']:
-                    grad = _muldiv(op, _expand(Y, dZ.shape[1:]).gather(
-                        0, argY.long())) * dZ
+                dX = th.zeros((X_shape[0],) + dZ.shape[1:],
+                              dtype=dtype, device=device)
+                if op == 'mul':
+                    grad = _expand(Y, dZ.shape[1:]).gather(
+                        0, argY.long()) * dZ
                     dX.scatter_add_(0, argX.long(), grad)
-                elif op in ['add', 'sub', 'copy_lhs']:
+                elif op in ['add', 'copy_lhs']:
                     dX.scatter_add_(0, argX.long(), dZ)
-            dX = _reduce_grad(dX, X.shape)
+            dX = _reduce_grad(dX, X_shape)
         else:  # X has not gradient
             dX = None
         if op != 'copy_lhs' and ctx.needs_input_grad[4]:
             if reduce_op == 'sum':
-                if op == 'mul' and _need_reduce_last_dim(X, Y):
+                if op == 'mul' and reduce_last:
                     dY = gsddmm(gidx, 'dot', X, dZ)
-                elif op in ['mul', 'div']:
+                elif op == 'mul':
                     dY = gsddmm(gidx, 'mul', X, dZ)
-                    if op == 'div':
-                        dY = -dY / (Y ** 2)
-                elif op in ['add', 'sub', 'copy_rhs']:
-                    dY = gsddmm(gidx, 'copy_rhs', X, _addsub(op, dZ))
+                elif op in ['add', 'copy_rhs']:
+                    dY = gsddmm(gidx, 'copy_rhs', X, dZ)
             else:  # max/min
-                dY = th.zeros((Y.shape[0],) + dZ.shape[1:],
-                              dtype=Y.dtype, device=Y.device)
-                if op in ['mul', 'div']:
+                dY = th.zeros((Y_shape[0],) + dZ.shape[1:],
+                              dtype=dtype, device=device)
+                if op == 'mul':
                     grad = _expand(X, dZ.shape[1:]).gather(
                         0, argX.long()) * dZ
                     dY.scatter_add_(0, argY.long(), grad)
-                    if op == 'div':
-                        dY = -dY / (Y ** 2)
-                elif op in ['add', 'sub', 'copy_rhs']:
-                    dY.scatter_add_(0, argY.long(), _addsub(op, dZ))
-            dY = _reduce_grad(dY, Y.shape)
+                elif op in ['add', 'copy_rhs']:
+                    dY.scatter_add_(0, argY.long(), dZ)
+            dY = _reduce_grad(dY, Y_shape)
         else:  # Y has no gradient
             dY = None
         return None, None, None, dX, dY
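Note (not part of the commit): the forward pass now stashes only shapes, dtype and device in ctx.backward_cache and hands operands to save_for_backward only when the caching rules say backward will use them, so unneeded activations can be freed. A minimal, self-contained sketch of the same pattern for a plain torch.autograd.Function, with hypothetical names:

import torch as th

class MulCacheAware(th.autograd.Function):
    """Sketch of the save-only-what-backward-needs pattern (not DGL code)."""

    @staticmethod
    def forward(ctx, X, Y):
        out = X * Y
        ctx.backward_cache = X.shape, Y.shape     # cheap metadata, always kept
        saved_X = X if Y.requires_grad else None  # dY = dZ * X
        saved_Y = Y if X.requires_grad else None  # dX = dZ * Y
        ctx.save_for_backward(saved_X, saved_Y)
        return out

    @staticmethod
    def backward(ctx, dZ):
        X_shape, Y_shape = ctx.backward_cache     # shapes survive even when
        X, Y = ctx.saved_tensors                  # the tensors were dropped
        dX = dZ * Y if Y is not None else None
        dY = dZ * X if X is not None else None
        return dX, dY

# Only x requires a gradient, so forward does not keep x alive at all.
x = th.randn(4, requires_grad=True)
y = th.randn(4)
MulCacheAware.apply(x, y).sum().backward()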
@@ -178,7 +222,7 @@ class GSpMM_hetero(th.autograd.Function):
         # TODO(Israt): implement other combinations of message and reduce functions
         if reduce_op == 'sum':
             if op in ['copy_rhs']:
-                tmp_Z = tuple([_addsub(op, dZ[i]) if dZ[i] is not None else None
+                tmp_Z = tuple([dZ[i] if dZ[i] is not None else None
                                for i in range(len(dZ))])
                 tmp = tuple(X + tmp_Z)
                 dY = gsddmm_hetero(g, 'copy_rhs', 'u', 'v', *tmp)
@@ -188,62 +232,81 @@ class GSpMM_hetero(th.autograd.Function):
             dY = tuple([None] * len(Y))
         return (None, None, None) + dX + dY
 
 
+def sddmm_cache_X(op, req_grad_X, req_grad_Y):
+    """Rules to identify whether to cache X in SDDMM forward stage."""
+    if op in ['mul', 'dot'] and req_grad_Y:
+        return True
+    return False
+
+
+def sddmm_cache_Y(op, req_grad_X, req_grad_Y):
+    """Rules to identify whether to cache Y in SDDMM forward stage."""
+    if op in ['mul', 'dot'] and req_grad_X:
+        return True
+    return False
+
+
 class GSDDMM(th.autograd.Function):
     @staticmethod
     @custom_fwd(cast_inputs=th.float16)
     def forward(ctx, gidx, op, X, Y, lhs_target, rhs_target):
         out = _gsddmm(gidx, op, X, Y, lhs_target, rhs_target)
-        ctx.backward_cache = gidx, op, lhs_target, rhs_target
+        X_shape = X.shape if X is not None else None
+        Y_shape = Y.shape if Y is not None else None
+        ctx.backward_cache = gidx, op, lhs_target, rhs_target, X_shape, Y_shape
+        req_grad_X = X.requires_grad if X is not None else False
+        req_grad_Y = Y.requires_grad if Y is not None else False
+        if not sddmm_cache_X(op, req_grad_X, req_grad_Y):
+            X = None
+        if not sddmm_cache_Y(op, req_grad_X, req_grad_Y):
+            Y = None
         ctx.save_for_backward(X, Y)
         return out
 
     @staticmethod
     @custom_bwd
     def backward(ctx, dZ):
-        gidx, op, lhs_target, rhs_target = ctx.backward_cache
+        gidx, op, lhs_target, rhs_target, X_shape, Y_shape = ctx.backward_cache
        X, Y = ctx.saved_tensors
         if op != 'copy_rhs' and ctx.needs_input_grad[2]:
             if lhs_target in ['u', 'v']:
                 _gidx = gidx if lhs_target == 'v' else gidx.reverse()
-                if op in ['add', 'sub', 'copy_lhs']:
+                if op in ['add', 'copy_lhs']:
                     dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)
-                else:  # mul, div, dot
+                else:  # mul, dot
                     if rhs_target == lhs_target:
-                        dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ) * _muldiv(op, Y)
+                        dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ) * Y
                     elif rhs_target == 'e':
-                        dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * _muldiv(op, Y))
+                        dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * Y)
                     else:  # rhs_target = !lhs_target
-                        dX = gspmm(_gidx, 'mul', 'sum', _muldiv(op, Y), dZ)
+                        dX = gspmm(_gidx, 'mul', 'sum', Y, dZ)
             else:  # lhs_target == 'e'
-                if op in ['add', 'sub', 'copy_lhs']:
+                if op in ['add', 'copy_lhs']:
                     dX = dZ
-                else:  # mul, div, dot
-                    dX = gsddmm(gidx, 'mul', dZ, _muldiv(op, Y), 'e', rhs_target)
-            dX = _reduce_grad(dX, X.shape)
+                else:  # mul, dot
+                    dX = gsddmm(gidx, 'mul', dZ, Y, 'e', rhs_target)
+            dX = _reduce_grad(dX, X_shape)
         else:
             dX = None
         if op != 'copy_lhs' and ctx.needs_input_grad[3]:
             if rhs_target in ['u', 'v']:
                 _gidx = gidx if rhs_target == 'v' else gidx.reverse()
-                if op in ['add', 'sub', 'copy_rhs']:
-                    dY = gspmm(_gidx, 'copy_rhs', 'sum', None, _addsub(op, dZ))
-                else:  # mul, div, dot
+                if op in ['add', 'copy_rhs']:
+                    dY = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)
+                else:  # mul, dot
                     if lhs_target == rhs_target:
                         dY = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ) * X
                     elif lhs_target == 'e':
                         dY = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * X)
                     else:  # rhs_target = !lhs_target
                         dY = gspmm(_gidx, 'mul', 'sum', X, dZ)
-                    if op == 'div':
-                        dY = -dY / (Y ** 2)
             else:
-                if op in ['add', 'sub', 'copy_rhs']:
-                    dY = _addsub(op, dZ)
-                else:  # mul, div, dot
+                if op in ['add', 'copy_rhs']:
+                    dY = dZ
+                else:  # mul, dot
                     dY = gsddmm(gidx, 'mul', dZ, X, 'e', lhs_target)
-                    if op == 'div':
-                        dY = -dY / (Y ** 2)
-            dY = _reduce_grad(dY, Y.shape)
+            dY = _reduce_grad(dY, Y_shape)
         else:
             dY = None
         return None, None, dX, dY, None, None
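Note (not part of the commit): the SDDMM caching rules are simpler because only 'mul' and 'dot' use one operand inside the other operand's gradient. A small sanity sketch, assuming the sddmm_cache_* definitions above:

# Illustrative only; uses the sddmm_cache_* functions defined above.
# For 'add' neither operand appears in the other's gradient:
assert sddmm_cache_X('add', True, True) is False
assert sddmm_cache_Y('add', True, True) is False
# For 'mul'/'dot' an operand is kept only if its counterpart needs a gradient:
assert sddmm_cache_X('mul', False, True) is True
assert sddmm_cache_Y('dot', True, False) is False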
@@ -422,9 +485,21 @@ class CSRMask(th.autograd.Function):
 def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
+    if op == 'sub':
+        op = 'add'
+        rhs_data = -rhs_data
+    if op == 'div':
+        op = 'mul'
+        rhs_data = 1. / rhs_data
     return GSpMM.apply(gidx, op, reduce_op, lhs_data, rhs_data)
 
 
 def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
+    if op == 'sub':
+        op = 'add'
+        rhs_data = -rhs_data
+    if op == 'div':
+        op = 'mul'
+        rhs_data = 1. / rhs_data
     return GSDDMM.apply(gidx, op, lhs_data, rhs_data, lhs_target, rhs_target)
 
 
 def gspmm_hetero(g, op, reduce_op, *lhs_and_rhs_tuple):
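Note (not part of the commit): with 'sub' and 'div' rewritten in the gspmm/gsddmm wrappers, the autograd Functions no longer need the _muldiv/_addsub helpers or the explicit -dY / (Y ** 2) correction; autograd differentiates the negation and reciprocal applied to rhs_data itself. A small numerical sketch of why the rewrite is gradient-equivalent, using plain tensors and made-up values:

import torch as th

y = (th.rand(5) + 0.5).requires_grad_()
x = th.rand(5)

out = x * (1. / y)                  # what the wrappers now pass down for 'div'
assert th.allclose(out, x / y)      # same forward value as a true division

out.sum().backward()
assert th.allclose(y.grad, -x / y ** 2)   # the old explicit correction term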