Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
5788c855
Commit
5788c855
authored
Jul 30, 2018
by
rusty1s
Browse files
clean up code
parent
ef2c346f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
68 additions
and
43 deletions
+68
-43
test/test_matmul.py
test/test_matmul.py
+8
-4
torch_sparse/__init__.py
torch_sparse/__init__.py
+2
-2
torch_sparse/matmul.py
torch_sparse/matmul.py
+51
-37
torch_sparse/transpose.py
torch_sparse/transpose.py
+7
-0
No files found.
test/test_matmul.py
View file @
5788c855
...
...
@@ -2,7 +2,7 @@ from itertools import product
import
pytest
import
torch
from
torch_sparse
import
spspmm
,
SparseTensor
from
torch_sparse
import
spspmm
from
.utils
import
dtypes
,
devices
,
tensor
...
...
@@ -17,9 +17,13 @@ def test_spspmm(dtype, device):
value
=
tensor
([
2
,
4
],
dtype
,
device
)
B
=
(
index
,
value
,
torch
.
Size
([
3
,
2
]))
index
,
value
=
spspmm
(
*
A
,
*
B
)
out
=
SparseTensor
(
index
,
value
,
torch
.
Size
([
3
,
2
]))
assert
out
.
to_dense
().
tolist
()
==
[[
8
,
0
],
[
0
,
6
],
[
0
,
8
]]
index
,
value
,
size
=
spspmm
(
*
A
,
*
B
)
print
(
index
)
print
(
value
)
print
(
size
)
# out = torch.sparse_coo_tensor(index, value, size)
# assert out.to_dense().tolist() == [[8, 0], [0, 6], [0, 8]]
# TODO TEST backward
# value.sum().backward()
torch_sparse/__init__.py
View file @
5788c855
from
.sparse
import
SparseTensor
from
.matmul
import
spspmm
from
.transpose
import
transpose
# Names exported via `from torch_sparse import *`.
__all__ = [
    'SparseTensor',
    'spspmm',
    'transpose',
]
torch_sparse/matmul.py
View file @
5788c855
import
torch
from
torch
import
from_numpy
from
scipy.
sparse
import
coo_matrix
import
scipy.sparse
from
torch_
sparse
import
transpose
from
torch_sparse
import
SparseTensor
import
matmul_cuda
if
torch
.
cuda
.
is_available
():
import
matmul_cuda
def spspmm(indexA, valueA, sizeA, indexB, valueB, sizeB):
    """Sparse-sparse matrix product C = A @ B of two COO matrices.

    Args:
        indexA: 2 x nnz LongTensor of (row, col) coordinates of A.
        valueA: non-zero values of A.
        sizeA: shape of A (must be 2-dimensional).
        indexB, valueB, sizeB: the same triple describing B.

    Returns:
        (index, value, size): COO triple describing the product, where
        size is torch.Size([sizeA[0], sizeB[1]]).
    """
    assert valueA.dtype == valueB.dtype
    assert len(sizeA) == len(sizeB) == 2
    # Inner dimensions must agree for the product to exist.
    assert sizeA[1] == sizeB[0]

    out_index, out_value = SpSpMM.apply(indexA, valueA, sizeA,
                                        indexB, valueB, sizeB)
    out_size = torch.Size([sizeA[0], sizeB[1]])
    return out_index, out_value, out_size
class SpSpMM(torch.autograd.Function):
    """Autograd wrapper for sparse-sparse matrix multiplication.

    Gradients flow to the value tensors only; the index tensors and the
    sizes are treated as non-differentiable.
    """

    @staticmethod
    def forward(ctx, indexA, valueA, sizeA, indexB, valueB, sizeB):
        index, value = mm(indexA, valueA, sizeA, indexB, valueB, sizeB)
        # Stash the operand shapes and tensors needed by backward.
        ctx.sizeA, ctx.sizeB = sizeA, sizeB
        ctx.save_for_backward(indexA, valueA, indexB, valueB, index)
        return index, value

    @staticmethod
    def backward(ctx, grad_index, grad_value):
        indexA, valueA, indexB, valueB, index = ctx.saved_variables
        grad_valueA = grad_valueB = None

        # The incoming gradient, expressed as a sparse COO triple with
        # the output's sparsity pattern and shape.
        grad = (index, grad_value, torch.Size([ctx.sizeA[0], ctx.sizeB[1]]))

        if ctx.needs_input_grad[1]:
            # dA = dC @ B^T
            B_transposed = transpose(indexB, valueB, ctx.sizeB)
            _, grad_valueA = mm(*grad, *B_transposed)

        if ctx.needs_input_grad[4]:
            # dB = A^T @ dC
            A_transposed = transpose(indexA, valueA, ctx.sizeA)
            _, grad_valueB = mm(*A_transposed, *grad)

        # One slot per forward argument: only the two value tensors get grads.
        return None, grad_valueA, None, None, grad_valueB, None
def mm(indexA, valueA, sizeA, indexB, valueB, sizeB):
    """Dispatch the sparse-sparse product to the CUDA or CPU backend.

    The device of the value tensor decides which implementation runs.
    Returns the (index, value) COO pair of the product.
    """
    impl = mm_cuda if valueA.is_cuda else mm_cpu
    return impl(indexA, valueA, sizeA, indexB, valueB, sizeB)
def mm_cuda(indexA, valueA, sizeA, indexB, valueB, sizeB):
    """GPU sparse-sparse product via the `matmul_cuda` extension.

    Both operands are wrapped as torch sparse COO tensors before being
    handed to the extension; returns the resulting (index, value) pair.
    """
    lhs = torch.sparse_coo_tensor(indexA, valueA, sizeA)
    rhs = torch.sparse_coo_tensor(indexB, valueB, sizeB)
    out_index, out_value = matmul_cuda.spspmm(lhs, rhs)
    return out_index, out_value
def mm_cpu(indexA, valueA, sizeA, indexB, valueB, sizeB):
    """CPU sparse-sparse product via scipy.

    Converts both operands to scipy COO matrices, multiplies them in CSR
    form, and converts the result back to a torch (index, value) pair.
    The output values are cast to valueA's dtype.
    """
    lhs = to_scipy(indexA, valueA, sizeA)
    rhs = to_scipy(indexB, valueB, sizeB)

    # CSR x CSR is scipy's fast path; go back to COO to read coordinates.
    prod = lhs.tocsr().dot(rhs.tocsr()).tocoo()

    row = torch.from_numpy(prod.row).long()
    col = torch.from_numpy(prod.col).long()
    index = torch.stack([row, col], dim=0)
    value = torch.from_numpy(prod.data).type_as(valueA)
    return index, value
def to_csr(index, value, size):
    """Convert a torch COO (index, value, size) triple to a scipy CSR matrix."""
    rows, cols = index.detach().numpy()
    data = value.detach().numpy()
    return coo_matrix((data, (rows, cols)), (size[0], size[1])).tocsr()
def to_scipy(index, value, size):
    """Convert a torch COO (index, value, size) triple to a scipy COO matrix."""
    coords = index.detach().numpy()
    data = value.detach().numpy()
    return scipy.sparse.coo_matrix((data, (coords[0], coords[1])), tuple(size))
torch_sparse/transpose.py
0 → 100644
View file @
5788c855
import
torch
def transpose(index, value, size):
    """Transpose a sparse COO matrix given as an (index, value, size) triple.

    Swaps the row/column coordinate rows and the two size dimensions.
    The values are returned unchanged and the coordinates are NOT
    re-sorted or coalesced.
    """
    row, col = index
    transposed_index = torch.stack([col, row], dim=0)
    transposed_size = torch.Size([size[1], size[0]])
    return transposed_index, value, transposed_size
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment