Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
6e9cd9d6
Commit
6e9cd9d6
authored
Feb 19, 2020
by
rusty1s
Browse files
sparse size to tuple
parent
e6f5c3f0
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
38 additions
and
35 deletions
+38
-35
test/test_storage.py
test/test_storage.py
+2
-2
torch_sparse/coalesce.py
torch_sparse/coalesce.py
+1
-1
torch_sparse/convert.py
torch_sparse/convert.py
+1
-1
torch_sparse/index_select.py
torch_sparse/index_select.py
+2
-2
torch_sparse/masked_select.py
torch_sparse/masked_select.py
+2
-2
torch_sparse/matmul.py
torch_sparse/matmul.py
+1
-1
torch_sparse/narrow.py
torch_sparse/narrow.py
+2
-2
torch_sparse/spspmm.py
torch_sparse/spspmm.py
+2
-2
torch_sparse/storage.py
torch_sparse/storage.py
+7
-7
torch_sparse/tensor.py
torch_sparse/tensor.py
+16
-13
torch_sparse/transpose.py
torch_sparse/transpose.py
+2
-2
No files found.
test/test_storage.py
View file @
6e9cd9d6
...
...
@@ -93,8 +93,8 @@ def test_utility(dtype, device):
storage
=
storage
.
set_value
(
value
,
layout
=
'coo'
)
assert
storage
.
value
().
tolist
()
==
[
1
,
2
,
3
,
4
]
storage
=
storage
.
sparse_resize
(
[
3
,
3
]
)
assert
storage
.
sparse_sizes
()
==
[
3
,
3
]
storage
=
storage
.
sparse_resize
(
(
3
,
3
)
)
assert
storage
.
sparse_sizes
()
==
(
3
,
3
)
new_storage
=
storage
.
copy
()
assert
new_storage
!=
storage
...
...
torch_sparse/coalesce.py
View file @
6e9cd9d6
...
...
@@ -20,6 +20,6 @@ def coalesce(index, value, m, n, op="add"):
"""
storage
=
SparseStorage
(
row
=
index
[
0
],
col
=
index
[
1
],
value
=
value
,
sparse_sizes
=
torch
.
Size
([
m
,
n
]
),
is_sorted
=
False
)
sparse_sizes
=
(
m
,
n
),
is_sorted
=
False
)
storage
=
storage
.
coalesce
(
reduce
=
op
)
return
torch
.
stack
([
storage
.
row
(),
storage
.
col
()],
dim
=
0
),
storage
.
value
()
torch_sparse/convert.py
View file @
6e9cd9d6
...
...
@@ -5,7 +5,7 @@ from torch import from_numpy
def
to_torch_sparse
(
index
,
value
,
m
,
n
):
return
torch
.
sparse_coo_tensor
(
index
.
detach
(),
value
,
torch
.
Size
([
m
,
n
]
))
return
torch
.
sparse_coo_tensor
(
index
.
detach
(),
value
,
(
m
,
n
))
def
from_torch_sparse
(
A
):
...
...
torch_sparse/index_select.py
View file @
6e9cd9d6
...
...
@@ -31,7 +31,7 @@ def index_select(src: SparseTensor, dim: int,
if
value
is
not
None
:
value
=
value
[
perm
]
sparse_sizes
=
torch
.
Size
([
idx
.
size
(
0
),
src
.
sparse_size
(
1
)
]
)
sparse_sizes
=
(
idx
.
size
(
0
),
src
.
sparse_size
(
1
))
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
rowcount
,
...
...
@@ -61,7 +61,7 @@ def index_select(src: SparseTensor, dim: int,
if
value
is
not
None
:
value
=
value
[
perm
][
csc2csr
]
sparse_sizes
=
torch
.
Size
([
src
.
sparse_size
(
0
),
idx
.
size
(
0
)
]
)
sparse_sizes
=
(
src
.
sparse_size
(
0
),
idx
.
size
(
0
))
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
None
,
...
...
torch_sparse/masked_select.py
View file @
6e9cd9d6
...
...
@@ -27,7 +27,7 @@ def masked_select(src: SparseTensor, dim: int,
if
value
is
not
None
:
value
=
value
[
mask
]
sparse_sizes
=
torch
.
Size
([
rowcount
.
size
(
0
),
src
.
sparse_size
(
1
)
]
)
sparse_sizes
=
(
rowcount
.
size
(
0
),
src
.
sparse_size
(
1
))
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
rowcount
,
...
...
@@ -54,7 +54,7 @@ def masked_select(src: SparseTensor, dim: int,
if
value
is
not
None
:
value
=
value
[
csr2csc
][
mask
][
csc2csr
]
sparse_sizes
=
torch
.
Size
([
src
.
sparse_size
(
0
),
colcount
.
size
(
0
)
]
)
sparse_sizes
=
(
src
.
sparse_size
(
0
),
colcount
.
size
(
0
))
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
None
,
...
...
torch_sparse/matmul.py
View file @
6e9cd9d6
...
...
@@ -82,7 +82,7 @@ def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
rowptrC
,
colC
,
valueC
=
torch
.
ops
.
torch_sparse
.
spspmm_sum
(
rowptrA
,
colA
,
valueA
,
rowptrB
,
colB
,
valueB
,
K
)
return
SparseTensor
(
row
=
None
,
rowptr
=
rowptrC
,
col
=
colC
,
value
=
valueC
,
sparse_sizes
=
torch
.
Size
([
M
,
K
]
),
is_sorted
=
True
)
sparse_sizes
=
(
M
,
K
),
is_sorted
=
True
)
def
spspmm_add
(
src
:
SparseTensor
,
other
:
SparseTensor
)
->
SparseTensor
:
...
...
torch_sparse/narrow.py
View file @
6e9cd9d6
...
...
@@ -30,7 +30,7 @@ def narrow(src: SparseTensor, dim: int, start: int,
if
value
is
not
None
:
value
=
value
.
narrow
(
0
,
row_start
,
row_length
)
sparse_sizes
=
torch
.
Size
([
length
,
src
.
sparse_size
(
1
)
]
)
sparse_sizes
=
(
length
,
src
.
sparse_size
(
1
))
rowcount
=
src
.
storage
.
_rowcount
if
rowcount
is
not
None
:
...
...
@@ -53,7 +53,7 @@ def narrow(src: SparseTensor, dim: int, start: int,
if
value
is
not
None
:
value
=
value
[
mask
]
sparse_sizes
=
torch
.
Size
([
src
.
sparse_size
(
0
),
length
]
)
sparse_sizes
=
(
src
.
sparse_size
(
0
),
length
)
colptr
=
src
.
storage
.
_colptr
if
colptr
is
not
None
:
...
...
torch_sparse/spspmm.py
View file @
6e9cd9d6
...
...
@@ -23,9 +23,9 @@ def spspmm(indexA, valueA, indexB, valueB, m, k, n, coalesced=False):
"""
A
=
SparseTensor
(
row
=
indexA
[
0
],
col
=
indexA
[
1
],
value
=
valueA
,
sparse_sizes
=
torch
.
Size
([
m
,
k
]
),
is_sorted
=
not
coalesced
)
sparse_sizes
=
(
m
,
k
),
is_sorted
=
not
coalesced
)
B
=
SparseTensor
(
row
=
indexB
[
0
],
col
=
indexB
[
1
],
value
=
valueB
,
sparse_sizes
=
torch
.
Size
([
k
,
n
]
),
is_sorted
=
not
coalesced
)
sparse_sizes
=
(
k
,
n
),
is_sorted
=
not
coalesced
)
C
=
matmul
(
A
,
B
)
row
,
col
,
value
=
C
.
coo
()
...
...
torch_sparse/storage.py
View file @
6e9cd9d6
import
warnings
from
typing
import
Optional
,
List
from
typing
import
Optional
,
List
,
Tuple
import
torch
from
torch_scatter
import
segment_csr
,
scatter_add
...
...
@@ -23,7 +23,7 @@ class SparseStorage(object):
_rowptr
:
Optional
[
torch
.
Tensor
]
_col
:
torch
.
Tensor
_value
:
Optional
[
torch
.
Tensor
]
_sparse_sizes
:
List
[
int
]
_sparse_sizes
:
Tuple
[
int
,
int
]
_rowcount
:
Optional
[
torch
.
Tensor
]
_colptr
:
Optional
[
torch
.
Tensor
]
_colcount
:
Optional
[
torch
.
Tensor
]
...
...
@@ -34,7 +34,7 @@ class SparseStorage(object):
rowptr
:
Optional
[
torch
.
Tensor
]
=
None
,
col
:
Optional
[
torch
.
Tensor
]
=
None
,
value
:
Optional
[
torch
.
Tensor
]
=
None
,
sparse_sizes
:
Optional
[
List
[
int
]]
=
None
,
sparse_sizes
:
Optional
[
Tuple
[
int
,
int
]]
=
None
,
rowcount
:
Optional
[
torch
.
Tensor
]
=
None
,
colptr
:
Optional
[
torch
.
Tensor
]
=
None
,
colcount
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -56,7 +56,7 @@ class SparseStorage(object):
else
:
raise
ValueError
N
=
col
.
max
().
item
()
+
1
sparse_sizes
=
torch
.
Size
([
int
(
M
),
int
(
N
)
]
)
sparse_sizes
=
(
int
(
M
),
int
(
N
))
else
:
assert
len
(
sparse_sizes
)
==
2
...
...
@@ -118,7 +118,7 @@ class SparseStorage(object):
self
.
_rowptr
=
rowptr
self
.
_col
=
col
self
.
_value
=
value
self
.
_sparse_sizes
=
sparse_sizes
self
.
_sparse_sizes
=
tuple
(
sparse_sizes
)
self
.
_rowcount
=
rowcount
self
.
_colptr
=
colptr
self
.
_colcount
=
colcount
...
...
@@ -218,13 +218,13 @@ class SparseStorage(object):
colcount
=
self
.
_colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
def
sparse_sizes
(
self
)
->
List
[
int
]:
def
sparse_sizes
(
self
)
->
Tuple
[
int
,
int
]:
return
self
.
_sparse_sizes
def
sparse_size
(
self
,
dim
:
int
)
->
int
:
return
self
.
_sparse_sizes
[
dim
]
def
sparse_resize
(
self
,
sparse_sizes
:
List
[
int
]):
def
sparse_resize
(
self
,
sparse_sizes
:
Tuple
[
int
,
int
]):
assert
len
(
sparse_sizes
)
==
2
old_sparse_sizes
,
nnz
=
self
.
_sparse_sizes
,
self
.
_col
.
numel
()
...
...
torch_sparse/tensor.py
View file @
6e9cd9d6
...
...
@@ -16,7 +16,8 @@ class SparseTensor(object):
rowptr
:
Optional
[
torch
.
Tensor
]
=
None
,
col
:
Optional
[
torch
.
Tensor
]
=
None
,
value
:
Optional
[
torch
.
Tensor
]
=
None
,
sparse_sizes
:
List
[
int
]
=
None
,
is_sorted
:
bool
=
False
):
sparse_sizes
:
Optional
[
Tuple
[
int
,
int
]]
=
None
,
is_sorted
:
bool
=
False
):
self
.
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
None
,
colptr
=
None
,
colcount
=
None
,
...
...
@@ -45,7 +46,8 @@ class SparseTensor(object):
value
=
mat
[
row
,
col
]
return
SparseTensor
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_sizes
=
mat
.
size
()[:
2
],
is_sorted
=
True
)
sparse_sizes
=
(
mat
.
size
(
0
),
mat
.
size
(
1
)),
is_sorted
=
True
)
@
classmethod
def
from_torch_sparse_coo_tensor
(
self
,
mat
:
torch
.
Tensor
,
...
...
@@ -59,7 +61,8 @@ class SparseTensor(object):
value
=
mat
.
_values
()
return
SparseTensor
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_sizes
=
mat
.
size
()[:
2
],
is_sorted
=
True
)
sparse_sizes
=
(
mat
.
size
(
0
),
mat
.
size
(
1
)),
is_sorted
=
True
)
@
classmethod
def
eye
(
self
,
M
:
int
,
N
:
Optional
[
int
]
=
None
,
...
...
@@ -105,10 +108,9 @@ class SparseTensor(object):
csr2csc
=
csc2csr
=
row
storage
:
SparseStorage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
torch
.
Size
([
M
,
N
]),
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
(
M
,
N
),
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
self
=
SparseTensor
.
__new__
(
SparseTensor
)
self
.
storage
=
storage
...
...
@@ -160,13 +162,13 @@ class SparseTensor(object):
layout
:
Optional
[
str
]
=
None
):
return
self
.
from_storage
(
self
.
storage
.
set_value
(
value
,
layout
))
def
sparse_sizes
(
self
)
->
List
[
int
]:
def
sparse_sizes
(
self
)
->
Tuple
[
int
,
int
]:
return
self
.
storage
.
sparse_sizes
()
def
sparse_size
(
self
,
dim
:
int
)
->
int
:
return
self
.
storage
.
sparse_sizes
()[
dim
]
def
sparse_resize
(
self
,
sparse_sizes
:
List
[
int
]):
def
sparse_resize
(
self
,
sparse_sizes
:
Tuple
[
int
,
int
]):
return
self
.
from_storage
(
self
.
storage
.
sparse_resize
(
sparse_sizes
))
def
is_coalesced
(
self
)
->
bool
:
...
...
@@ -206,11 +208,12 @@ class SparseTensor(object):
return
self
.
set_value
(
value
,
layout
=
'coo'
)
def
sizes
(
self
)
->
List
[
int
]:
sizes
=
self
.
sparse_sizes
()
sparse_
sizes
=
self
.
sparse_sizes
()
value
=
self
.
storage
.
value
()
if
value
is
not
None
:
sizes
=
list
(
sizes
)
+
list
(
value
.
size
())[
1
:]
return
sizes
return
list
(
sparse_sizes
)
+
list
(
value
.
size
())[
1
:]
else
:
return
list
(
sparse_sizes
)
def
size
(
self
,
dim
:
int
)
->
int
:
return
self
.
sizes
()[
dim
]
...
...
@@ -268,7 +271,7 @@ class SparseTensor(object):
N
=
max
(
self
.
size
(
0
),
self
.
size
(
1
))
out
=
SparseTensor
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_sizes
=
torch
.
Size
([
N
,
N
]
),
is_sorted
=
False
)
sparse_sizes
=
(
N
,
N
),
is_sorted
=
False
)
out
=
out
.
coalesce
(
reduce
)
return
out
...
...
torch_sparse/transpose.py
View file @
6e9cd9d6
...
...
@@ -19,7 +19,7 @@ def t(src: SparseTensor) -> SparseTensor:
rowptr
=
src
.
storage
.
_colptr
,
col
=
row
[
csr2csc
],
value
=
value
,
sparse_sizes
=
torch
.
Size
([
sparse_sizes
[
1
],
sparse_sizes
[
0
]
]
),
sparse_sizes
=
(
sparse_sizes
[
1
],
sparse_sizes
[
0
]),
rowcount
=
src
.
storage
.
_colcount
,
colptr
=
src
.
storage
.
_rowptr
,
colcount
=
src
.
storage
.
_rowcount
,
...
...
@@ -53,7 +53,7 @@ def transpose(index, value, m, n, coalesced=True):
row
,
col
=
col
,
row
if
coalesced
:
sparse_sizes
=
torch
.
Size
([
n
,
m
]
)
sparse_sizes
=
(
n
,
m
)
storage
=
SparseStorage
(
row
=
row
,
col
=
col
,
value
=
value
,
sparse_sizes
=
sparse_sizes
,
is_sorted
=
False
)
storage
=
storage
.
coalesce
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment