Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
bd49e20a
Commit
bd49e20a
authored
Jan 22, 2020
by
rusty1s
Browse files
spmm backward implementation
parent
df5f7063
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
48 additions
and
80 deletions
+48
-80
cpu/spmm.cpp
cpu/spmm.cpp
+1
-1
test/test_matmul.py
test/test_matmul.py
+29
-67
torch_sparse/matmul.py
torch_sparse/matmul.py
+18
-12
No files found.
cpu/spmm.cpp
View file @
bd49e20a
...
@@ -109,7 +109,7 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
...
@@ -109,7 +109,7 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
at
::
optional
<
at
::
Tensor
>
arg_out
=
at
::
nullopt
;
at
::
optional
<
at
::
Tensor
>
arg_out
=
at
::
nullopt
;
int64_t
*
arg_out_data
=
nullptr
;
int64_t
*
arg_out_data
=
nullptr
;
if
(
reduce2REDUCE
.
at
(
reduce
)
==
MIN
||
reduce2REDUCE
.
at
(
reduce
)
==
MAX
)
{
if
(
reduce2REDUCE
.
at
(
reduce
)
==
MIN
||
reduce2REDUCE
.
at
(
reduce
)
==
MAX
)
{
arg_out
=
at
::
full_like
(
out
,
-
1
,
rowptr
.
options
());
arg_out
=
at
::
full_like
(
out
,
col
.
numel
()
,
rowptr
.
options
());
arg_out_data
=
arg_out
.
value
().
DATA_PTR
<
int64_t
>
();
arg_out_data
=
arg_out
.
value
().
DATA_PTR
<
int64_t
>
();
}
}
...
...
test/test_matmul.py
View file @
bd49e20a
...
@@ -2,91 +2,53 @@ from itertools import product
...
@@ -2,91 +2,53 @@ from itertools import product
import
pytest
import
pytest
import
torch
import
torch
from
torch.autograd
import
gradcheck
from
torch_sparse.matmul
import
matmul
from
torch_sparse.matmul
import
matmul
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
import
torch_scatter
import
torch_scatter
from
.utils
import
tensor
,
devices
,
dtypes
from
.utils
import
devices
,
grad_
dtypes
devices
=
[
'cpu'
]
devices
=
[
'cpu'
]
dtypes
=
[
torch
.
float
]
grad_
dtypes
=
[
torch
.
float
]
reductions
=
[
'sum'
,
'mean'
,
'min'
,
'max'
]
reductions
=
[
'sum'
,
'mean'
,
'min'
,
'max'
]
# grad_reductions = ['sum', 'mean']
reductions
=
[
'min'
]
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
def
test_spmm_forward
(
dtype
,
device
):
src_dense
=
torch
.
randn
((
5
,
4
),
dtype
=
dtype
,
device
=
device
)
src
=
SparseTensor
.
from_dense
(
src_dense
)
src
.
requires_grad_
()
src_dense
=
src_dense
.
clone
().
requires_grad_
()
other
=
torch
.
randn
((
4
,
8
),
dtype
=
dtype
,
device
=
device
)
other
.
requires_grad_
()
out1
=
matmul
(
src
,
other
)
grad_out
=
torch
.
randn_like
(
out1
)
out1
.
backward
(
grad_out
)
other
.
grad
=
None
out2
=
torch
.
matmul
(
src_dense
,
other
)
out2
.
backward
(
grad_out
)
# assert torch.allclose(out1, out2)
# assert torch.allclose(src.storage.value.grad.view(5, 4), src_dense.grad)
@
pytest
.
mark
.
parametrize
(
'dtype,device,reduce'
,
@
pytest
.
mark
.
parametrize
(
'dtype,device,reduce'
,
product
(
dtypes
,
devices
,
reductions
))
product
(
grad_
dtypes
,
devices
,
reductions
))
def
test_spmm
(
dtype
,
device
,
reduce
):
def
test_spmm
(
dtype
,
device
,
reduce
):
src
=
torch
.
ones
((
5
,
4
),
dtype
=
dtype
,
device
=
device
)
src
=
torch
.
randn
((
10
,
8
),
dtype
=
dtype
,
device
=
device
)
src
[
2
,
:]
=
0
# Delete one row...
src
[
2
]
=
0
src
[
:,
2
:
4
]
=
0
# Delete one col...
src
=
SparseTensor
.
from_dense
(
src
).
requires_grad_
()
src
=
SparseTensor
.
from_dense
(
src
).
requires_grad_
()
src
.
set_value_
(
None
)
(
row
,
col
),
value
=
src
.
coo
(
)
other
=
torch
.
randn
((
2
,
4
,
2
),
dtype
=
dtype
,
device
=
device
,
other
=
torch
.
randn
((
2
,
8
,
2
),
dtype
=
dtype
,
device
=
device
,
requires_grad
=
True
)
requires_grad
=
True
)
(
row
,
col
),
value
=
src
.
coo
()
src_col
=
other
.
index_select
(
-
2
,
col
)
*
value
.
unsqueeze
(
-
1
)
out1
=
other
.
index_select
(
-
2
,
col
)
# * value.unsqueeze(-1)
func
=
'add'
if
reduce
==
'sum'
else
reduce
func
=
'add'
if
reduce
==
'sum'
else
reduce
out1
=
getattr
(
torch_scatter
,
f
'scatter_
{
func
}
'
)(
out1
,
row
,
dim
=-
2
)
expected
=
getattr
(
torch_scatter
,
f
'scatter_
{
func
}
'
)(
src_col
,
row
,
dim
=-
2
)
out1
=
out1
[
0
]
if
isinstance
(
out1
,
tuple
)
else
out1
expected
=
expected
[
0
]
if
isinstance
(
expected
,
tuple
)
else
expected
if
reduce
==
'min'
:
grad_out
=
torch
.
randn_like
(
out1
)
expected
[
expected
>
1000
]
=
0
out1
.
backward
(
grad_out
)
if
reduce
==
'max'
:
# grad_value1 = value.grad
expected
[
expected
<
1000
]
=
0
# value.grad = None
grad_other1
=
other
.
grad
grad_out
=
torch
.
randn_like
(
expected
)
expected
.
backward
(
grad_out
)
expected_grad_value
=
value
.
grad
value
.
grad
=
None
expected_grad_other
=
other
.
grad
other
.
grad
=
None
other
.
grad
=
None
print
(
reduce
)
out
=
matmul
(
src
,
other
,
reduce
)
out2
=
matmul
(
src
,
other
,
reduce
)
out
=
out
[
0
]
if
isinstance
(
out
,
tuple
)
else
out
out2
=
out2
[
0
]
if
isinstance
(
out2
,
tuple
)
else
out2
out
.
backward
(
grad_out
)
out2
.
backward
(
grad_out
)
# grad_value2 = value.grad
# value.grad = None
grad_other2
=
other
.
grad
other
.
grad
=
None
# assert torch.allclose(out1, out2)
# assert torch.allclose(grad_value1, grad_value2)
assert
torch
.
allclose
(
grad_other1
,
grad_other2
)
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
def
test_spmm_backward
(
dtype
,
device
):
src_dense
=
torch
.
randn
((
5
,
4
),
dtype
=
torch
.
double
,
device
=
device
)
src
=
SparseTensor
.
from_dense
(
src_dense
)
src
.
requires_grad_
()
other
=
torch
.
randn
((
4
,
8
),
dtype
=
torch
.
double
,
device
=
device
)
other
.
requires_grad_
()
# assert gradcheck(matmul, (src, other, "sum"))
assert
torch
.
allclose
(
expected
,
out
)
assert
torch
.
allclose
(
expected_grad_value
,
value
.
grad
)
assert
torch
.
allclose
(
expected_grad_other
,
other
.
grad
)
torch_sparse/matmul.py
View file @
bd49e20a
...
@@ -34,6 +34,12 @@ class SPMM(torch.autograd.Function):
...
@@ -34,6 +34,12 @@ class SPMM(torch.autograd.Function):
data
=
ctx
.
saved_tensors
data
=
ctx
.
saved_tensors
index
,
rowcount
,
rowptr
,
colptr
,
csr2csc
,
value
,
mat
,
arg_out
=
data
index
,
rowcount
,
rowptr
,
colptr
,
csr2csc
,
value
,
mat
,
arg_out
=
data
invalid_arg_mask
=
arg_out_ind
=
None
if
ctx
.
reduce
in
[
'min'
,
'max'
]
and
(
ctx
.
needs_input_grad
[
5
]
or
ctx
.
needs_input_grad
[
6
]):
invalid_arg_mask
=
arg_out
==
index
.
size
(
1
)
arg_out_ind
=
arg_out
.
masked_fill
(
invalid_arg_mask
,
-
1
)
grad_value
=
None
grad_value
=
None
if
ctx
.
needs_input_grad
[
5
]:
if
ctx
.
needs_input_grad
[
5
]:
if
ctx
.
reduce
in
[
'sum'
,
'add'
]:
if
ctx
.
reduce
in
[
'sum'
,
'add'
]:
...
@@ -45,12 +51,12 @@ class SPMM(torch.autograd.Function):
...
@@ -45,12 +51,12 @@ class SPMM(torch.autograd.Function):
rowptr
,
index
[
1
],
mat
,
grad_out
,
ctx
.
reduce
)
rowptr
,
index
[
1
],
mat
,
grad_out
,
ctx
.
reduce
)
elif
ctx
.
reduce
in
[
'min'
,
'max'
]:
elif
ctx
.
reduce
in
[
'min'
,
'max'
]:
col
=
index
[
1
][
arg_out
.
flatten
()].
view_as
(
arg_out
)
col
=
index
[
1
][
arg_out
_ind
.
flatten
()].
view_as
(
arg_out
)
out
=
mat
.
gather
(
-
2
,
col
).
mul_
(
grad_out
)
out
=
mat
.
gather
(
-
2
,
col
).
mul_
(
grad_out
)
out
.
masked_fill_
(
arg_out
==
-
1
,
0
)
out
.
masked_fill_
(
invalid_arg_mask
,
0
)
col
=
col
.
add_
(
rowptr
[:
-
1
].
view
(
-
1
,
1
))
grad_value
=
scatter_add
(
out
.
flatten
(),
arg_out
.
flatten
(),
grad_value
=
scatter_add
(
out
.
flatten
(),
col
.
flatten
(),
dim
=
0
,
dim
=
0
,
dim_size
=
value
.
numel
()
+
1
)
dim_size
=
value
.
numel
())
grad_value
=
grad_value
[:
-
1
]
grad_mat
=
None
grad_mat
=
None
if
ctx
.
needs_input_grad
[
6
]:
if
ctx
.
needs_input_grad
[
6
]:
...
@@ -70,12 +76,12 @@ class SPMM(torch.autograd.Function):
...
@@ -70,12 +76,12 @@ class SPMM(torch.autograd.Function):
elif
ctx
.
reduce
in
[
'min'
,
'max'
]:
elif
ctx
.
reduce
in
[
'min'
,
'max'
]:
if
value
is
not
None
:
if
value
is
not
None
:
value
=
value
[
arg_out
.
flatten
()].
view_as
(
arg_out
)
value
=
value
[
arg_out
_ind
.
flatten
()].
view_as
(
arg_out
)
value
=
value
.
mul_
(
grad_out
)
value
=
value
.
mul_
(
grad_out
)
else
:
else
:
value
=
grad_out
value
=
grad_out
value
.
masked_fill_
(
arg_out
==
-
1
,
0
)
value
.
masked_fill_
(
invalid_arg_mask
,
0
)
col
=
index
[
1
][
arg_out
.
flatten
()].
view_as
(
arg_out
)
col
=
index
[
1
][
arg_out
_ind
.
flatten
()].
view_as
(
arg_out
)
grad_mat
=
scatter_add
(
value
,
col
,
dim
=-
2
,
grad_mat
=
scatter_add
(
value
,
col
,
dim
=-
2
,
dim_size
=
mat
.
size
(
-
2
))
dim_size
=
mat
.
size
(
-
2
))
...
@@ -89,14 +95,14 @@ def matmul(src, other, reduce='sum'):
...
@@ -89,14 +95,14 @@ def matmul(src, other, reduce='sum'):
assert
reduce
in
[
'sum'
,
'add'
,
'mean'
,
'min'
,
'max'
]
assert
reduce
in
[
'sum'
,
'add'
,
'mean'
,
'min'
,
'max'
]
(
index
,
value
),
rowptr
=
src
.
coo
(),
src
.
storage
.
rowptr
(
index
,
value
),
rowptr
=
src
.
coo
(),
src
.
storage
.
rowptr
csr2csc
=
colptr
=
None
if
other
.
requires_grad
and
reduce
in
[
'sum'
,
'add'
,
'mean'
]:
csr2csc
,
colptr
=
src
.
storage
.
csr2csc
,
src
.
storage
.
colptr
rowcount
=
None
rowcount
=
None
if
other
.
requires_grad
and
reduce
in
[
'mean'
]:
if
other
.
requires_grad
and
reduce
in
[
'mean'
]:
rowcount
=
src
.
storage
.
rowcount
rowcount
=
src
.
storage
.
rowcount
csr2csc
=
colptr
=
None
if
other
.
requires_grad
and
reduce
in
[
'sum'
,
'add'
,
'mean'
]:
csr2csc
,
colptr
=
src
.
storage
.
csr2csc
,
src
.
storage
.
colptr
return
SPMM
.
apply
(
index
,
rowcount
,
rowptr
,
colptr
,
csr2csc
,
value
,
return
SPMM
.
apply
(
index
,
rowcount
,
rowptr
,
colptr
,
csr2csc
,
value
,
other
,
reduce
)
other
,
reduce
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment