Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
0ae0e784
Commit
0ae0e784
authored
Jul 28, 2018
by
rusty1s
Browse files
backward implementation
parent
572227be
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
48 additions
and
23 deletions
+48
-23
test/test_matmul.py
test/test_matmul.py
+18
-0
torch_sparse/matmul.py
torch_sparse/matmul.py
+30
-23
No files found.
test/test_matmul.py
View file @
0ae0e784
import
torch
from
torch_sparse
import
spspmm
def
test_spspmm
():
e1
=
torch
.
tensor
([[
0
,
0
,
1
,
2
,
2
],
[
1
,
2
,
0
,
0
,
1
]])
v1
=
torch
.
tensor
([
1
,
2
,
3
,
4
,
5
],
dtype
=
torch
.
float
,
requires_grad
=
True
)
matrix1
=
(
e1
,
v1
,
torch
.
Size
([
3
,
3
]))
e2
=
torch
.
tensor
([[
0
,
2
],
[
1
,
0
]])
v2
=
torch
.
tensor
([
2
,
4
],
dtype
=
torch
.
float
,
requires_grad
=
True
)
matrix2
=
(
e2
,
v2
,
torch
.
Size
([
3
,
2
]))
index
,
value
=
spspmm
(
*
matrix1
,
*
matrix2
)
out
=
torch
.
sparse
.
FloatTensor
(
index
,
value
,
torch
.
Size
([
3
,
2
])).
to_dense
()
assert
out
.
tolist
()
==
[[
8
,
0
],
[
0
,
6
],
[
0
,
8
]]
value
.
sum
().
backward
()
torch_sparse/matmul.py
View file @
0ae0e784
...
@@ -5,42 +5,49 @@ from scipy.sparse import coo_matrix
...
@@ -5,42 +5,49 @@ from scipy.sparse import coo_matrix
class
SpSpMM
(
torch
.
autograd
.
Function
):
class
SpSpMM
(
torch
.
autograd
.
Function
):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
matrix1
,
matrix2
):
def
forward
(
ctx
,
e1
,
v1
,
s1
,
e2
,
v2
,
s2
):
ctx
.
save_for_backawrd
(
matrix1
,
matrix2
)
e
,
v
=
mm
(
e1
,
v1
,
s1
,
e2
,
v2
,
s2
)
return
mm
(
matrix1
,
matrix2
)
ctx
.
s1
,
ctx
.
s2
=
s1
,
s2
ctx
.
save_for_backward
(
e1
,
v1
,
e2
,
v2
,
e
)
return
e
,
v
@
staticmethod
@
staticmethod
def
backward
(
ctx
,
grad_out
):
def
backward
(
ctx
,
grad_e
,
grad_v
):
matrix1
,
matrix2
=
ctx
.
saved_variables
e1
,
v1
,
e2
,
v2
,
e
=
ctx
.
saved_variables
grad_matrix1
=
grad_matrix2
=
None
grad_v1
=
grad_v2
=
None
grad
=
(
e
,
grad_v
,
torch
.
Size
([
ctx
.
s1
[
0
],
ctx
.
s2
[
1
]]))
if
ctx
.
needs_input_grad
[
0
]:
if
ctx
.
needs_input_grad
[
1
]:
grad_matrix1
=
mm
(
grad_out
,
matrix2
.
t
())
e2
=
torch
.
stack
([
e2
[
1
],
e2
[
0
]],
dim
=
0
)
_
,
grad_v1
=
mm
(
*
grad
,
e2
,
v2
,
torch
.
Size
([
ctx
.
s2
[
1
],
ctx
.
s2
[
0
]]))
if
ctx
.
needs_input_grad
[
0
]:
if
ctx
.
needs_input_grad
[
4
]:
grad_matrix2
=
mm
(
matrix1
.
t
(),
grad_out
)
e1
=
torch
.
stack
([
e1
[
1
],
e1
[
0
]],
dim
=
0
)
_
,
grad_v2
=
mm
(
e1
,
v1
,
torch
.
Size
([
ctx
.
s1
[
1
],
ctx
.
s1
[
0
]]),
*
grad
)
return
grad_matrix1
,
grad_matrix2
return
None
,
grad_v1
,
None
,
None
,
grad_v2
,
None
spspmm
=
SpSpMM
.
apply
spspmm
=
SpSpMM
.
apply
def
mm
(
A
,
B
):
def
mm
(
e1
,
v1
,
s1
,
e2
,
v2
,
s2
):
if
A
[
0
]
.
is_cuda
:
if
e1
.
is_cuda
:
pass
pass
else
:
else
:
return
mm_cpu
(
A
,
B
)
return
mm_cpu
(
e1
,
v1
,
s1
,
e2
,
v2
,
s2
)
def
mm_cpu
(
A
,
B
):
def
mm_cpu
(
e1
,
v1
,
s1
,
e2
,
v2
,
s2
):
A
,
B
,
=
to_csr
(
A
),
to_csr
(
B
)
matrix1
,
matrix2
,
=
to_csr
(
e1
,
v1
,
s1
),
to_csr
(
e2
,
v2
,
s2
)
C
=
A
.
dot
(
B
).
tocoo
()
out
=
matrix1
.
dot
(
matrix2
).
tocoo
()
row
,
col
,
value
=
from_numpy
(
C
.
row
),
from_numpy
(
C
.
col
)
,
from_numpy
(
C
.
data
)
row
,
col
=
from_numpy
(
out
.
row
)
.
long
()
,
from_numpy
(
out
.
col
)
.
long
(
)
return
torch
.
stack
([
row
,
col
],
dim
=
0
),
value
return
torch
.
stack
([
row
,
col
],
dim
=
0
),
from_numpy
(
out
.
data
)
def
to_csr
(
A
):
def
to_csr
(
index
,
value
,
size
):
(
row
,
col
),
value
,
size
=
A
index
,
value
=
index
.
detach
().
numpy
(),
value
.
detach
().
numpy
()
row
,
col
,
value
=
row
.
numpy
(),
col
.
numpy
(),
value
.
numpy
(
)
shape
=
(
size
[
0
],
size
[
1
]
)
return
coo_matrix
((
value
,
(
row
,
col
)),
shape
=
(
size
[
0
],
size
[
1
])
).
tocsr
()
return
coo_matrix
((
value
,
(
index
[
0
],
index
[
1
])),
shape
).
tocsr
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment