Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
64a8e2ce
Commit
64a8e2ce
authored
May 08, 2020
by
Mario Geiger
Browse files
view
parent
57852a66
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
91 additions
and
0 deletions
+91
-0
test/test_view.py
test/test_view.py
+31
-0
torch_sparse/__init__.py
torch_sparse/__init__.py
+2
-0
torch_sparse/view.py
torch_sparse/view.py
+58
-0
No files found.
test/test_view.py
0 → 100644
View file @
64a8e2ce
from
itertools
import
product
import
pytest
import
torch
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse
import
view
from
.utils
import
dtypes
,
devices
,
tensor
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
def
test_view_matrix
(
dtype
,
device
):
row
=
torch
.
tensor
([
0
,
1
,
1
],
device
=
device
)
col
=
torch
.
tensor
([
1
,
0
,
2
],
device
=
device
)
index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
value
=
tensor
([
1
,
2
,
3
],
dtype
,
device
)
index
,
value
=
view
(
index
,
value
,
m
=
2
,
n
=
3
,
new_n
=
2
)
assert
index
.
tolist
()
==
[[
0
,
1
,
2
],
[
1
,
1
,
1
]]
assert
value
.
tolist
()
==
[
1
,
2
,
3
]
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
def
test_view_sparse_tensor
(
dtype
,
device
):
options
=
torch
.
tensor
(
0
,
dtype
=
dtype
,
device
=
device
)
mat
=
SparseTensor
.
eye
(
4
,
options
=
options
).
view
(
2
,
8
)
assert
mat
.
storage
.
sparse_sizes
()
==
(
2
,
8
)
assert
mat
.
storage
.
row
().
tolist
()
==
[
0
,
0
,
1
,
1
]
assert
mat
.
storage
.
col
().
tolist
()
==
[
0
,
5
,
2
,
7
]
assert
mat
.
storage
.
value
().
tolist
()
==
[
1
,
1
,
1
,
1
]
torch_sparse/__init__.py
View file @
64a8e2ce
...
@@ -55,6 +55,7 @@ from .convert import to_torch_sparse, from_torch_sparse # noqa
...
@@ -55,6 +55,7 @@ from .convert import to_torch_sparse, from_torch_sparse # noqa
from
.convert
import
to_scipy
,
from_scipy
# noqa
from
.convert
import
to_scipy
,
from_scipy
# noqa
from
.coalesce
import
coalesce
# noqa
from
.coalesce
import
coalesce
# noqa
from
.transpose
import
transpose
# noqa
from
.transpose
import
transpose
# noqa
from
.view
import
view
# noqa
from
.eye
import
eye
# noqa
from
.eye
import
eye
# noqa
from
.spmm
import
spmm
# noqa
from
.spmm
import
spmm
# noqa
from
.spspmm
import
spspmm
# noqa
from
.spspmm
import
spspmm
# noqa
...
@@ -101,6 +102,7 @@ __all__ = [
...
@@ -101,6 +102,7 @@ __all__ = [
'from_scipy'
,
'from_scipy'
,
'coalesce'
,
'coalesce'
,
'transpose'
,
'transpose'
,
'view'
,
'eye'
,
'eye'
,
'spmm'
,
'spmm'
,
'spspmm'
,
'spspmm'
,
...
...
torch_sparse/view.py
0 → 100644
View file @
64a8e2ce
import
torch
from
torch_sparse.storage
import
SparseStorage
from
torch_sparse.tensor
import
SparseTensor
def
_view
(
src
:
SparseTensor
,
n
:
int
,
layout
:
str
=
'csr'
)
->
SparseTensor
:
row
,
col
,
value
=
src
.
coo
()
sparse_sizes
=
src
.
storage
.
sparse_sizes
()
if
sparse_sizes
[
0
]
*
sparse_sizes
[
1
]
%
n
==
0
:
raise
RuntimeError
(
f
"shape '[-1,
{
n
}
]' is invalid for input of size
{
sparse_sizes
[
0
]
*
sparse_sizes
[
1
]
}
"
)
assert
layout
==
'csr'
or
layout
==
'csc'
if
layout
==
'csr'
:
idx
=
sparse_sizes
[
1
]
*
row
+
col
row
=
idx
//
n
col
=
idx
%
n
sparse_sizes
=
(
sparse_sizes
[
0
]
*
sparse_sizes
[
1
]
//
n
,
n
)
if
layout
==
'csc'
:
idx
=
sparse_sizes
[
0
]
*
col
+
row
row
=
idx
%
n
col
=
idx
//
n
sparse_sizes
=
(
n
,
sparse_sizes
[
0
]
*
sparse_sizes
[
1
]
//
n
)
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
src
.
storage
.
_rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
src
.
storage
.
_rowcount
,
colptr
=
src
.
storage
.
_colptr
,
colcount
=
src
.
storage
.
_colcount
,
csr2csc
=
src
.
storage
.
_csr2csc
,
csc2csr
=
src
.
storage
.
_csc2csr
,
is_sorted
=
True
,
)
return
src
.
from_storage
(
storage
)
SparseTensor
.
view
=
lambda
self
,
m
,
n
:
_view
(
self
,
n
,
layout
=
'csr'
)
###############################################################################
def
view
(
index
,
value
,
m
,
n
,
new_n
):
assert
m
*
n
%
new_n
==
0
row
,
col
=
index
idx
=
n
*
row
+
col
row
=
idx
//
new_n
col
=
idx
%
new_n
return
torch
.
stack
([
row
,
col
],
dim
=
0
),
value
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment