Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
0e2ddfad
Commit
0e2ddfad
authored
May 11, 2020
by
rusty1s
Browse files
added view to storage + rename
parent
4dec4df0
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
50 additions
and
88 deletions
+50
-88
test/test_storage.py
test/test_storage.py
+21
-0
test/test_view.py
test/test_view.py
+0
-31
torch_sparse/__init__.py
torch_sparse/__init__.py
+0
-2
torch_sparse/storage.py
torch_sparse/storage.py
+25
-0
torch_sparse/tensor.py
torch_sparse/tensor.py
+4
-0
torch_sparse/view.py
torch_sparse/view.py
+0
-55
No files found.
test/test_storage.py
View file @
0e2ddfad
...
@@ -122,3 +122,24 @@ def test_coalesce(dtype, device):
...
@@ -122,3 +122,24 @@ def test_coalesce(dtype, device):
assert
storage
.
row
().
tolist
()
==
[
0
,
0
,
1
,
1
]
assert
storage
.
row
().
tolist
()
==
[
0
,
0
,
1
,
1
]
assert
storage
.
col
().
tolist
()
==
[
0
,
1
,
0
,
1
]
assert
storage
.
col
().
tolist
()
==
[
0
,
1
,
0
,
1
]
assert
storage
.
value
().
tolist
()
==
[
1
,
2
,
3
,
4
]
assert
storage
.
value
().
tolist
()
==
[
1
,
2
,
3
,
4
]
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
def
test_sparse_reshape
(
dtype
,
device
):
row
,
col
=
tensor
([[
0
,
1
,
2
,
3
],
[
0
,
1
,
2
,
3
]],
torch
.
long
,
device
)
storage
=
SparseStorage
(
row
=
row
,
col
=
col
)
storage
=
storage
.
sparse_reshape
(
2
,
8
)
assert
storage
.
sparse_sizes
()
==
(
2
,
8
)
assert
storage
.
row
().
tolist
()
==
[
0
,
0
,
1
,
1
]
assert
storage
.
col
().
tolist
()
==
[
0
,
5
,
2
,
7
]
storage
=
storage
.
sparse_reshape
(
-
1
,
4
)
assert
storage
.
sparse_sizes
()
==
(
4
,
4
)
assert
storage
.
row
().
tolist
()
==
[
0
,
1
,
2
,
3
]
assert
storage
.
col
().
tolist
()
==
[
0
,
1
,
2
,
3
]
storage
=
storage
.
sparse_reshape
(
2
,
-
1
)
assert
storage
.
sparse_sizes
()
==
(
2
,
8
)
assert
storage
.
row
().
tolist
()
==
[
0
,
0
,
1
,
1
]
assert
storage
.
col
().
tolist
()
==
[
0
,
5
,
2
,
7
]
test/test_view.py
deleted
100644 → 0
View file @
4dec4df0
from
itertools
import
product
import
pytest
import
torch
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse
import
view
from
.utils
import
dtypes
,
devices
,
tensor
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
def
test_view_matrix
(
dtype
,
device
):
row
=
torch
.
tensor
([
0
,
1
,
1
],
device
=
device
)
col
=
torch
.
tensor
([
1
,
0
,
2
],
device
=
device
)
index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
value
=
tensor
([
1
,
2
,
3
],
dtype
,
device
)
index
,
value
=
view
(
index
,
value
,
m
=
2
,
n
=
3
,
new_n
=
2
)
assert
index
.
tolist
()
==
[[
0
,
1
,
2
],
[
1
,
1
,
1
]]
assert
value
.
tolist
()
==
[
1
,
2
,
3
]
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
def
test_view_sparse_tensor
(
dtype
,
device
):
options
=
torch
.
tensor
(
0
,
dtype
=
dtype
,
device
=
device
)
mat
=
SparseTensor
.
eye
(
4
,
options
=
options
).
view
(
2
,
8
)
assert
mat
.
storage
.
sparse_sizes
()
==
(
2
,
8
)
assert
mat
.
storage
.
row
().
tolist
()
==
[
0
,
0
,
1
,
1
]
assert
mat
.
storage
.
col
().
tolist
()
==
[
0
,
5
,
2
,
7
]
assert
mat
.
storage
.
value
().
tolist
()
==
[
1
,
1
,
1
,
1
]
torch_sparse/__init__.py
View file @
0e2ddfad
...
@@ -55,7 +55,6 @@ from .convert import to_torch_sparse, from_torch_sparse # noqa
...
@@ -55,7 +55,6 @@ from .convert import to_torch_sparse, from_torch_sparse # noqa
from
.convert
import
to_scipy
,
from_scipy
# noqa
from
.convert
import
to_scipy
,
from_scipy
# noqa
from
.coalesce
import
coalesce
# noqa
from
.coalesce
import
coalesce
# noqa
from
.transpose
import
transpose
# noqa
from
.transpose
import
transpose
# noqa
from
.view
import
view
# noqa
from
.eye
import
eye
# noqa
from
.eye
import
eye
# noqa
from
.spmm
import
spmm
# noqa
from
.spmm
import
spmm
# noqa
from
.spspmm
import
spspmm
# noqa
from
.spspmm
import
spspmm
# noqa
...
@@ -102,7 +101,6 @@ __all__ = [
...
@@ -102,7 +101,6 @@ __all__ = [
'from_scipy'
,
'from_scipy'
,
'coalesce'
,
'coalesce'
,
'transpose'
,
'transpose'
,
'view'
,
'eye'
,
'eye'
,
'spmm'
,
'spmm'
,
'spspmm'
,
'spspmm'
,
...
...
torch_sparse/storage.py
View file @
0e2ddfad
...
@@ -260,6 +260,31 @@ class SparseStorage(object):
...
@@ -260,6 +260,31 @@ class SparseStorage(object):
colcount
=
colcount
,
csr2csc
=
self
.
_csr2csc
,
colcount
=
colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
def
sparse_reshape
(
self
,
num_rows
:
int
,
num_cols
:
int
):
assert
num_rows
>
0
or
num_rows
==
-
1
assert
num_cols
>
0
or
num_cols
==
-
1
assert
num_rows
>
0
or
num_cols
>
0
total
=
self
.
sparse_size
(
0
)
*
self
.
sparse_size
(
1
)
if
num_rows
==
-
1
:
num_rows
=
total
//
num_cols
if
num_cols
==
-
1
:
num_cols
=
total
//
num_rows
assert
num_rows
*
num_cols
==
total
idx
=
self
.
sparse_size
(
1
)
*
self
.
row
()
+
self
.
col
()
row
=
idx
/
num_cols
col
=
idx
%
num_cols
return
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
self
.
_value
,
sparse_sizes
=
(
num_rows
,
num_cols
),
rowcount
=
None
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
def
has_rowcount
(
self
)
->
bool
:
def
has_rowcount
(
self
)
->
bool
:
return
self
.
_rowcount
is
not
None
return
self
.
_rowcount
is
not
None
...
...
torch_sparse/tensor.py
View file @
0e2ddfad
...
@@ -171,6 +171,10 @@ class SparseTensor(object):
...
@@ -171,6 +171,10 @@ class SparseTensor(object):
def
sparse_resize
(
self
,
sparse_sizes
:
Tuple
[
int
,
int
]):
def
sparse_resize
(
self
,
sparse_sizes
:
Tuple
[
int
,
int
]):
return
self
.
from_storage
(
self
.
storage
.
sparse_resize
(
sparse_sizes
))
return
self
.
from_storage
(
self
.
storage
.
sparse_resize
(
sparse_sizes
))
def
sparse_reshape
(
self
,
num_rows
:
int
,
num_cols
:
int
):
return
self
.
from_storage
(
self
.
storage
.
sparse_reshape
(
num_rows
,
num_cols
))
def
is_coalesced
(
self
)
->
bool
:
def
is_coalesced
(
self
)
->
bool
:
return
self
.
storage
.
is_coalesced
()
return
self
.
storage
.
is_coalesced
()
...
...
torch_sparse/view.py
deleted
100644 → 0
View file @
4dec4df0
import
torch
from
torch_sparse.storage
import
SparseStorage
from
torch_sparse.tensor
import
SparseTensor
def
_view
(
src
:
SparseTensor
,
n
:
int
,
layout
:
str
=
'csr'
)
->
SparseTensor
:
row
,
col
,
value
=
src
.
coo
()
sparse_sizes
=
src
.
storage
.
sparse_sizes
()
if
sparse_sizes
[
0
]
*
sparse_sizes
[
1
]
%
n
!=
0
:
raise
RuntimeError
(
f
"shape '[-1,
{
n
}
]' is invalid for input of size "
f
"
{
sparse_sizes
[
0
]
*
sparse_sizes
[
1
]
}
"
)
assert
layout
==
'csr'
or
layout
==
'csc'
if
layout
==
'csr'
:
idx
=
sparse_sizes
[
1
]
*
row
+
col
row
=
idx
//
n
col
=
idx
%
n
sparse_sizes
=
(
sparse_sizes
[
0
]
*
sparse_sizes
[
1
]
//
n
,
n
)
if
layout
==
'csc'
:
idx
=
sparse_sizes
[
0
]
*
col
+
row
row
=
idx
%
n
col
=
idx
//
n
sparse_sizes
=
(
n
,
sparse_sizes
[
0
]
*
sparse_sizes
[
1
]
//
n
)
storage
=
SparseStorage
(
row
=
row
,
col
=
col
,
value
=
value
,
sparse_sizes
=
sparse_sizes
,
csr2csc
=
src
.
storage
.
_csr2csc
,
csc2csr
=
src
.
storage
.
_csc2csr
,
is_sorted
=
True
,
)
return
src
.
from_storage
(
storage
)
SparseTensor
.
view
=
lambda
self
,
m
,
n
:
_view
(
self
,
n
,
layout
=
'csr'
)
###############################################################################
def
view
(
index
,
value
,
m
,
n
,
new_n
):
assert
m
*
n
%
new_n
==
0
row
,
col
=
index
idx
=
n
*
row
+
col
row
=
idx
//
new_n
col
=
idx
%
new_n
return
torch
.
stack
([
row
,
col
],
dim
=
0
),
value
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment