OpenDAS / torch-sparse

Commit 9732a518, authored Mar 22, 2019 by rusty1s
torch sparse convert + transpose cleanup
parent 0b790779

Showing 6 changed files with 59 additions and 44 deletions.
  test/test_convert.py        +23  -0
  test/test_transpose.py      +14  -12
  torch_sparse/__init__.py    +4   -3
  torch_sparse/convert.py     +8   -0
  torch_sparse/spspmm.py      +4   -5
  torch_sparse/transpose.py   +6   -24
test/test_convert.py  (new file, mode 100644, +23 -0)

import torch
from torch_sparse import to_scipy, from_scipy
from torch_sparse import to_torch_sparse, from_torch_sparse


def test_convert_scipy():
    index = torch.tensor([[0, 0, 1, 2, 2], [0, 2, 1, 0, 1]])
    value = torch.Tensor([1, 2, 4, 1, 3])
    N = 3

    out = from_scipy(to_scipy(index, value, N, N))
    assert out[0].tolist() == index.tolist()
    assert out[1].tolist() == value.tolist()


def test_convert_torch_sparse():
    index = torch.tensor([[0, 0, 1, 2, 2], [0, 2, 1, 0, 1]])
    value = torch.Tensor([1, 2, 4, 1, 3])
    N = 3

    out = from_torch_sparse(to_torch_sparse(index, value, N, N).coalesce())
    assert out[0].tolist() == index.tolist()
    assert out[1].tolist() == value.tolist()
test/test_transpose.py  (+14 -12)

@@ -2,29 +2,31 @@ from itertools import product

 import pytest
 import torch
-from torch_sparse import transpose, transpose_matrix
+from torch_sparse import transpose

 from .utils import dtypes, devices, tensor


-def test_transpose():
-    row = torch.tensor([1, 0, 1, 0, 2, 1])
-    col = torch.tensor([0, 1, 1, 1, 0, 0])
+@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
+def test_transpose_matrix(dtype, device):
+    row = torch.tensor([1, 0, 1, 2], device=device)
+    col = torch.tensor([0, 1, 1, 0], device=device)
     index = torch.stack([row, col], dim=0)
-    value = torch.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]])
+    value = tensor([1, 2, 3, 4], dtype, device)

     index, value = transpose(index, value, m=3, n=2)
     assert index.tolist() == [[0, 0, 1, 1], [1, 2, 0, 1]]
-    assert value.tolist() == [[7, 9], [5, 6], [6, 8], [3, 4]]
+    assert value.tolist() == [1, 4, 2, 3]


 @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
-def test_transpose_matrix(dtype, device):
-    row = torch.tensor([1, 0, 1, 2], device=device)
-    col = torch.tensor([0, 1, 1, 0], device=device)
+def test_transpose(dtype, device):
+    row = torch.tensor([1, 0, 1, 0, 2, 1], device=device)
+    col = torch.tensor([0, 1, 1, 1, 0, 0], device=device)
     index = torch.stack([row, col], dim=0)
-    value = tensor([1, 2, 3, 4], dtype, device)
+    value = tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]], dtype,
+                   device)

-    index, value = transpose_matrix(index, value, m=3, n=2)
+    index, value = transpose(index, value, m=3, n=2)
     assert index.tolist() == [[0, 0, 1, 1], [1, 2, 0, 1]]
-    assert value.tolist() == [1, 4, 2, 3]
+    assert value.tolist() == [[7, 9], [5, 6], [6, 8], [3, 4]]
torch_sparse/__init__.py  (+4 -3)

-from .convert import to_scipy, from_scipy
+from .convert import to_torch_sparse, from_torch_sparse, to_scipy, from_scipy
 from .coalesce import coalesce
-from .transpose import transpose, transpose_matrix
+from .transpose import transpose
 from .eye import eye
 from .spmm import spmm
 from .spspmm import spspmm
...
@@ -9,11 +9,12 @@ __version__ = '0.3.0'
 __all__ = [
     '__version__',
+    'to_torch_sparse',
+    'from_torch_sparse',
     'to_scipy',
     'from_scipy',
     'coalesce',
     'transpose',
-    'transpose_matrix',
     'eye',
     'spmm',
     'spspmm',
...
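The practical effect of the __init__.py change is that both converter pairs become part of the package's public namespace, while transpose_matrix is no longer exported. An illustrative import (not part of the commit):

from torch_sparse import (to_torch_sparse, from_torch_sparse,
                          to_scipy, from_scipy, transpose)
# `from torch_sparse import transpose_matrix` would now raise ImportError.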
torch_sparse/convert.py  (+8 -0)

@@ -4,6 +4,14 @@ import torch
 from torch import from_numpy


+def to_torch_sparse(index, value, m, n):
+    return torch.sparse_coo_tensor(index.detach(), value, torch.Size([m, n]))
+
+
+def from_torch_sparse(A):
+    return A.indices().detach(), A.values()
+
+
 def to_scipy(index, value, m, n):
     assert not index.is_cuda and not value.is_cuda
     (row, col), data = index.detach(), value.detach()
...
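A minimal usage sketch of the two new helpers, mirroring test/test_convert.py above (illustrative, not part of the commit). Since from_torch_sparse reads .indices(), the sparse tensor must be coalesced first:

import torch
from torch_sparse import to_torch_sparse, from_torch_sparse

index = torch.tensor([[0, 0, 1, 2, 2], [0, 2, 1, 0, 1]])
value = torch.Tensor([1, 2, 4, 1, 3])

# Build a 3 x 3 torch.sparse_coo_tensor and convert it back.
A = to_torch_sparse(index, value, 3, 3).coalesce()
index_out, value_out = from_torch_sparse(A)

assert index_out.tolist() == index.tolist()
assert value_out.tolist() == value.tolist()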
torch_sparse/spspmm.py  (+4 -5)

 import torch
-from torch_sparse import transpose_matrix, to_scipy, from_scipy
+from torch_sparse import transpose, to_scipy, from_scipy
 import torch_sparse.spspmm_cpu
...
@@ -53,9 +53,8 @@ class SpSpMM(torch.autograd.Function):
                                                  valueB, m, k)
         if ctx.needs_input_grad[3]:
-            indexA, valueA = transpose_matrix(indexA, valueA, m, k)
-            indexC, grad_valueC = transpose_matrix(indexC, grad_valueC, m,
-                                                   n)
+            indexA, valueA = transpose(indexA, valueA, m, k)
+            indexC, grad_valueC = transpose(indexC, grad_valueC, m, n)
             grad_valueB = torch_sparse.spspmm_cpu.spspmm_bw(
                 indexB, indexA.detach(), valueA, indexC.detach(),
                 grad_valueC, k, n)
...
@@ -66,7 +65,7 @@ class SpSpMM(torch.autograd.Function):
                                           indexB.detach(), valueB, m, k)
         if ctx.needs_input_grad[3]:
-            indexA_T, valueA_T = transpose_matrix(indexA, valueA, m, k)
+            indexA_T, valueA_T = transpose(indexA, valueA, m, k)
             grad_indexB, grad_valueB = mm(indexA_T, valueA_T, indexC,
                                           grad_valueC, k, m, n)
             grad_valueB = lift(grad_indexB, grad_valueB, indexB, n)
...
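For context on why the backward pass transposes its operands: for C = A·B, the gradient with respect to B is Aᵀ·(∂L/∂C), which is what the transpose() calls above supply in sparse form. A quick dense-tensor check of that identity (illustrative only, not from the repository):

import torch

A = torch.rand(3, 4)
B = torch.rand(4, 2, requires_grad=True)

C = A @ B
grad_C = torch.ones_like(C)
C.backward(grad_C)

# dL/dB == A^T @ dL/dC
assert torch.allclose(B.grad, A.t() @ grad_C)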
torch_sparse/transpose.py  (+6 -24)

@@ -14,31 +14,13 @@ def transpose(index, value, m, n):
     :rtype: (:class:`LongTensor`, :class:`Tensor`)
     """
-    row, col = index
-    index = torch.stack([col, row], dim=0)
-    index, value = coalesce(index, value, n, m)
-    return index, value
-
-
-def transpose_matrix(index, value, m, n):
-    """Transposes dimensions 0 and 1 of a sparse matrix, where :args:`value` is
-    one-dimensional.
-
-    Args:
-        index (:class:`LongTensor`): The index tensor of sparse matrix.
-        value (:class:`Tensor`): The value tensor of sparse matrix.
-        m (int): The first dimension of sparse matrix.
-        n (int): The second dimension of sparse matrix.
-
-    :rtype: (:class:`LongTensor`, :class:`Tensor`)
-    """
-    assert value.dim() == 1
-    if index.is_cuda:
-        return transpose(index, value, m, n)
-    else:
+    if value.dim() == 1 and not value.is_cuda:
         mat = to_scipy(index, value, m, n).tocsc()
         (col, row), value = from_scipy(mat)
         index = torch.stack([row, col], dim=0)
         return index, value
+
+    row, col = index
+    index = torch.stack([col, row], dim=0)
+    index, value = coalesce(index, value, n, m)
+    return index, value
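After this cleanup a single transpose() call covers both code paths: one-dimensional CPU values take the scipy CSC round trip, while everything else (multi-dimensional or CUDA values) falls through to coalesce. A short sketch based on the values in test/test_transpose.py (illustrative; the multi-dimensional output for this index is not asserted in the test):

import torch
from torch_sparse import transpose

index = torch.tensor([[1, 0, 1, 2], [0, 1, 1, 0]])

# 1-D CPU values: handled via the scipy CSC conversion.
value = torch.tensor([1.0, 2.0, 3.0, 4.0])
index_t, value_t = transpose(index, value, m=3, n=2)
# index_t -> [[0, 0, 1, 1], [1, 2, 0, 1]], value_t -> [1, 4, 2, 3]

# Multi-dimensional values: handled via coalesce.
value = torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]])
index_t, value_t = transpose(index, value, m=3, n=2)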