Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
ac25b416
Commit
ac25b416
authored
Jan 26, 2020
by
rusty1s
Browse files
fixed tests
parent
a1c268a5
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
125 additions
and
110 deletions
+125
-110
test/test_add.py
test/test_add.py
+6
-6
test/test_cat.py
test/test_cat.py
+15
-8
test/test_diag.py
test/test_diag.py
+9
-13
test/test_eye.py
test/test_eye.py
+17
-11
test/test_storage.py
test/test_storage.py
+35
-31
torch_sparse/cat.py
torch_sparse/cat.py
+34
-33
torch_sparse/diag.py
torch_sparse/diag.py
+3
-3
torch_sparse/storage.py
torch_sparse/storage.py
+6
-5
No files found.
test/test_add.py
View file @
ac25b416
...
...
@@ -22,21 +22,21 @@ def test_sparse_add(dtype, device):
mat1
=
mat
[:,
0
:
100000
]
mat2
=
mat
[:,
100000
:
200000
]
print
(
mat1
.
shape
)
print
(
mat2
.
shape
)
#
print(mat1.shape)
#
print(mat2.shape)
# 0.0159 to beat
t
=
time
.
perf_counter
()
mat
=
sparse_add
(
mat1
,
mat2
)
print
(
time
.
perf_counter
()
-
t
)
print
(
mat
.
nnz
())
#
print(time.perf_counter() - t)
#
print(mat.nnz())
mat1
=
mat_scipy
[:,
0
:
100000
]
mat2
=
mat_scipy
[:,
100000
:
200000
]
t
=
time
.
perf_counter
()
mat
=
mat1
+
mat2
print
(
time
.
perf_counter
()
-
t
)
print
(
mat
.
nnz
)
#
print(time.perf_counter() - t)
#
print(mat.nnz)
# mat1 + mat2
...
...
test/test_cat.py
View file @
ac25b416
...
...
@@ -8,24 +8,27 @@ from .utils import devices, tensor
@
pytest
.
mark
.
parametrize
(
'device'
,
devices
)
def
test_cat
(
device
):
index
=
tensor
([[
0
,
0
,
1
],
[
0
,
1
,
2
]],
torch
.
long
,
device
)
mat1
=
SparseTensor
(
index
)
row
,
col
=
tensor
([[
0
,
0
,
1
],
[
0
,
1
,
2
]],
torch
.
long
,
device
)
mat1
=
SparseTensor
(
row
=
row
,
col
=
col
)
mat1
.
fill_cache_
()
index
=
tensor
([[
0
,
0
,
1
,
2
],
[
0
,
1
,
1
,
0
]],
torch
.
long
,
device
)
mat2
=
SparseTensor
(
index
)
row
,
col
=
tensor
([[
0
,
0
,
1
,
2
],
[
0
,
1
,
1
,
0
]],
torch
.
long
,
device
)
mat2
=
SparseTensor
(
row
=
row
,
col
=
col
)
mat2
.
fill_cache_
()
out
=
cat
([
mat1
,
mat2
],
dim
=
0
)
assert
out
.
to_dense
().
tolist
()
==
[[
1
,
1
,
0
],
[
0
,
0
,
1
],
[
1
,
1
,
0
],
[
0
,
1
,
0
],
[
1
,
0
,
0
]]
assert
len
(
out
.
storage
.
cached_keys
())
==
2
assert
out
.
storage
.
has_rowcount
()
assert
out
.
storage
.
has_row
()
assert
out
.
storage
.
has_rowptr
()
assert
len
(
out
.
storage
.
cached_keys
())
==
1
assert
out
.
storage
.
has_rowcount
()
out
=
cat
([
mat1
,
mat2
],
dim
=
1
)
assert
out
.
to_dense
().
tolist
()
==
[[
1
,
1
,
0
,
1
,
1
],
[
0
,
0
,
1
,
0
,
1
],
[
0
,
0
,
0
,
1
,
0
]]
assert
out
.
storage
.
has_row
()
assert
not
out
.
storage
.
has_rowptr
()
assert
len
(
out
.
storage
.
cached_keys
())
==
2
assert
out
.
storage
.
has_colcount
()
assert
out
.
storage
.
has_colptr
()
...
...
@@ -34,9 +37,13 @@ def test_cat(device):
assert
out
.
to_dense
().
tolist
()
==
[[
1
,
1
,
0
,
0
,
0
],
[
0
,
0
,
1
,
0
,
0
],
[
0
,
0
,
0
,
1
,
1
],
[
0
,
0
,
0
,
0
,
1
],
[
0
,
0
,
0
,
1
,
0
]]
assert
len
(
out
.
storage
.
cached_keys
())
==
6
assert
out
.
storage
.
has_row
()
assert
out
.
storage
.
has_rowptr
()
assert
len
(
out
.
storage
.
cached_keys
())
==
5
mat1
.
set_value_
(
torch
.
randn
((
mat1
.
nnz
(),
4
),
device
=
device
))
out
=
cat
([
mat1
,
mat1
],
dim
=-
1
)
assert
out
.
storage
.
value
.
size
()
==
(
mat1
.
nnz
(),
8
)
assert
len
(
out
.
storage
.
cached_keys
())
==
6
assert
out
.
storage
.
has_row
()
assert
out
.
storage
.
has_rowptr
()
assert
len
(
out
.
storage
.
cached_keys
())
==
5
test/test_diag.py
View file @
ac25b416
...
...
@@ -9,26 +9,25 @@ from .utils import dtypes, devices, tensor
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
def
test_remove_diag
(
dtype
,
device
):
index
=
tensor
([
[
0
,
0
,
1
,
2
],
[
0
,
1
,
2
,
2
],
],
torch
.
long
,
device
)
row
,
col
=
tensor
([[
0
,
0
,
1
,
2
],
[
0
,
1
,
2
,
2
]],
torch
.
long
,
device
)
value
=
tensor
([
1
,
2
,
3
,
4
],
dtype
,
device
)
mat
=
SparseTensor
(
index
,
value
)
mat
=
SparseTensor
(
row
=
row
,
col
=
col
,
value
=
value
)
mat
.
fill_cache_
()
mat
=
mat
.
remove_diag
()
assert
mat
.
storage
.
index
.
tolist
()
==
[[
0
,
1
],
[
1
,
2
]]
assert
mat
.
storage
.
row
.
tolist
()
==
[
0
,
1
]
assert
mat
.
storage
.
col
.
tolist
()
==
[
1
,
2
]
assert
mat
.
storage
.
value
.
tolist
()
==
[
2
,
3
]
assert
len
(
mat
.
cached_keys
())
==
2
assert
mat
.
storage
.
rowcount
.
tolist
()
==
[
1
,
1
,
0
]
assert
mat
.
storage
.
colcount
.
tolist
()
==
[
0
,
1
,
1
]
mat
=
SparseTensor
(
index
,
value
)
mat
=
SparseTensor
(
row
=
row
,
col
=
col
,
value
=
value
)
mat
.
fill_cache_
()
mat
=
mat
.
remove_diag
(
k
=
1
)
assert
mat
.
storage
.
index
.
tolist
()
==
[[
0
,
2
],
[
0
,
2
]]
assert
mat
.
storage
.
row
.
tolist
()
==
[
0
,
2
]
assert
mat
.
storage
.
col
.
tolist
()
==
[
0
,
2
]
assert
mat
.
storage
.
value
.
tolist
()
==
[
1
,
4
]
assert
len
(
mat
.
cached_keys
())
==
2
assert
mat
.
storage
.
rowcount
.
tolist
()
==
[
1
,
0
,
1
]
...
...
@@ -37,12 +36,9 @@ def test_remove_diag(dtype, device):
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
def
test_set_diag
(
dtype
,
device
):
index
=
tensor
([
[
0
,
0
,
9
,
9
],
[
0
,
1
,
0
,
1
],
],
torch
.
long
,
device
)
row
,
col
=
tensor
([[
0
,
0
,
9
,
9
],
[
0
,
1
,
0
,
1
]],
torch
.
long
,
device
)
value
=
tensor
([
1
,
2
,
3
,
4
],
dtype
,
device
)
mat
=
SparseTensor
(
index
,
value
)
mat
=
SparseTensor
(
row
=
row
,
col
=
col
,
value
=
value
)
k
=
-
8
mat
=
mat
.
set_diag
(
k
)
test/test_eye.py
View file @
ac25b416
...
...
@@ -9,31 +9,37 @@ from .utils import dtypes, devices
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
def
test_eye
(
dtype
,
device
):
mat
=
SparseTensor
.
eye
(
3
,
dtype
=
dtype
,
device
=
device
)
assert
mat
.
storage
.
index
.
tolist
()
==
[[
0
,
1
,
2
],
[
0
,
1
,
2
]]
assert
mat
.
storage
.
row
.
tolist
()
==
[
0
,
1
,
2
]
assert
mat
.
storage
.
rowptr
.
tolist
()
==
[
0
,
1
,
2
,
3
]
assert
mat
.
storage
.
col
.
tolist
()
==
[
0
,
1
,
2
]
assert
mat
.
storage
.
value
.
tolist
()
==
[
1
,
1
,
1
]
assert
len
(
mat
.
cached_keys
())
==
0
mat
=
SparseTensor
.
eye
(
3
,
dtype
=
dtype
,
device
=
device
,
no_value
=
True
)
assert
mat
.
storage
.
index
.
tolist
()
==
[[
0
,
1
,
2
],
[
0
,
1
,
2
]]
mat
=
SparseTensor
.
eye
(
3
,
dtype
=
dtype
,
device
=
device
,
has_value
=
False
)
assert
mat
.
storage
.
row
.
tolist
()
==
[
0
,
1
,
2
]
assert
mat
.
storage
.
rowptr
.
tolist
()
==
[
0
,
1
,
2
,
3
]
assert
mat
.
storage
.
col
.
tolist
()
==
[
0
,
1
,
2
]
assert
mat
.
storage
.
value
is
None
assert
len
(
mat
.
cached_keys
())
==
0
mat
=
SparseTensor
.
eye
(
3
,
4
,
dtype
=
dtype
,
device
=
device
,
fill_cache
=
True
)
assert
mat
.
storage
.
index
.
tolist
()
==
[[
0
,
1
,
2
],
[
0
,
1
,
2
]]
assert
len
(
mat
.
cached_keys
())
==
6
assert
mat
.
storage
.
rowcount
.
tolist
()
==
[
1
,
1
,
1
]
assert
mat
.
storage
.
row
.
tolist
()
==
[
0
,
1
,
2
]
assert
mat
.
storage
.
rowptr
.
tolist
()
==
[
0
,
1
,
2
,
3
]
assert
mat
.
storage
.
colcount
.
tolist
()
==
[
1
,
1
,
1
,
0
]
assert
mat
.
storage
.
col
.
tolist
()
==
[
0
,
1
,
2
]
assert
len
(
mat
.
cached_keys
())
==
5
assert
mat
.
storage
.
rowcount
.
tolist
()
==
[
1
,
1
,
1
]
assert
mat
.
storage
.
colptr
.
tolist
()
==
[
0
,
1
,
2
,
3
,
3
]
assert
mat
.
storage
.
colcount
.
tolist
()
==
[
1
,
1
,
1
,
0
]
assert
mat
.
storage
.
csr2csc
.
tolist
()
==
[
0
,
1
,
2
]
assert
mat
.
storage
.
csc2csr
.
tolist
()
==
[
0
,
1
,
2
]
mat
=
SparseTensor
.
eye
(
4
,
3
,
dtype
=
dtype
,
device
=
device
,
fill_cache
=
True
)
assert
mat
.
storage
.
index
.
tolist
()
==
[[
0
,
1
,
2
],
[
0
,
1
,
2
]]
assert
len
(
mat
.
cached_keys
())
==
6
assert
mat
.
storage
.
rowcount
.
tolist
()
==
[
1
,
1
,
1
,
0
]
assert
mat
.
storage
.
row
.
tolist
()
==
[
0
,
1
,
2
]
assert
mat
.
storage
.
rowptr
.
tolist
()
==
[
0
,
1
,
2
,
3
,
3
]
assert
mat
.
storage
.
colcount
.
tolist
()
==
[
1
,
1
,
1
]
assert
mat
.
storage
.
col
.
tolist
()
==
[
0
,
1
,
2
]
assert
len
(
mat
.
cached_keys
())
==
5
assert
mat
.
storage
.
rowcount
.
tolist
()
==
[
1
,
1
,
1
,
0
]
assert
mat
.
storage
.
colptr
.
tolist
()
==
[
0
,
1
,
2
,
3
]
assert
mat
.
storage
.
colcount
.
tolist
()
==
[
1
,
1
,
1
]
assert
mat
.
storage
.
csr2csc
.
tolist
()
==
[
0
,
1
,
2
]
assert
mat
.
storage
.
csc2csr
.
tolist
()
==
[
0
,
1
,
2
]
test/test_storage.py
View file @
ac25b416
...
...
@@ -10,31 +10,30 @@ from .utils import dtypes, devices, tensor
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
def
test_storage
(
dtype
,
device
):
index
=
tensor
([[
0
,
0
,
1
,
1
],
[
0
,
1
,
0
,
1
]],
torch
.
long
,
device
)
row
,
col
=
tensor
([[
0
,
0
,
1
,
1
],
[
0
,
1
,
0
,
1
]],
torch
.
long
,
device
)
storage
=
SparseStorage
(
index
)
assert
storage
.
index
.
tolist
()
==
index
.
tolist
()
storage
=
SparseStorage
(
row
=
row
,
col
=
col
)
assert
storage
.
row
.
tolist
()
==
[
0
,
0
,
1
,
1
]
assert
storage
.
col
.
tolist
()
==
[
0
,
1
,
0
,
1
]
assert
storage
.
value
is
None
assert
storage
.
sparse_size
()
==
(
2
,
2
)
assert
storage
.
sparse_size
==
(
2
,
2
)
index
=
tensor
([[
0
,
0
,
1
,
1
],
[
1
,
0
,
1
,
0
]],
torch
.
long
,
device
)
row
,
col
=
tensor
([[
0
,
0
,
1
,
1
],
[
1
,
0
,
1
,
0
]],
torch
.
long
,
device
)
value
=
tensor
([
2
,
1
,
4
,
3
],
dtype
,
device
)
storage
=
SparseStorage
(
index
,
value
)
assert
storage
.
index
.
tolist
()
==
[[
0
,
0
,
1
,
1
],
[
0
,
1
,
0
,
1
]]
storage
=
SparseStorage
(
row
=
row
,
col
=
col
,
value
=
value
)
assert
storage
.
row
.
tolist
()
==
[
0
,
0
,
1
,
1
]
assert
storage
.
col
.
tolist
()
==
[
0
,
1
,
0
,
1
]
assert
storage
.
value
.
tolist
()
==
[
1
,
2
,
3
,
4
]
assert
storage
.
sparse_size
()
==
(
2
,
2
)
assert
storage
.
sparse_size
==
(
2
,
2
)
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
def
test_caching
(
dtype
,
device
):
index
=
tensor
([[
0
,
0
,
1
,
1
],
[
0
,
1
,
0
,
1
]],
torch
.
long
,
device
)
storage
=
SparseStorage
(
index
)
row
,
col
=
tensor
([[
0
,
0
,
1
,
1
],
[
0
,
1
,
0
,
1
]],
torch
.
long
,
device
)
storage
=
SparseStorage
(
row
=
row
,
col
=
col
)
assert
storage
.
_index
.
tolist
()
==
index
.
tolist
()
assert
storage
.
_row
.
tolist
()
==
row
.
tolist
()
assert
storage
.
_col
.
tolist
()
==
col
.
tolist
()
assert
storage
.
_value
is
None
assert
storage
.
_rowcount
is
None
...
...
@@ -52,12 +51,15 @@ def test_caching(dtype, device):
assert
storage
.
_csr2csc
.
tolist
()
==
[
0
,
2
,
1
,
3
]
assert
storage
.
_csc2csr
.
tolist
()
==
[
0
,
2
,
1
,
3
]
assert
storage
.
cached_keys
()
==
[
'rowcount'
,
'
row
ptr'
,
'colcount'
,
'colptr'
,
'csr2csc'
,
'csc2csr'
'rowcount'
,
'
col
ptr'
,
'colcount'
,
'csr2csc'
,
'csc2csr'
]
storage
=
SparseStorage
(
index
,
storage
.
value
,
storage
.
sparse_size
(),
storage
.
rowcount
,
storage
.
rowptr
,
storage
.
colcount
,
storage
.
colptr
,
storage
.
csr2csc
,
storage
.
csc2csr
)
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
storage
.
rowptr
,
col
=
col
,
value
=
storage
.
value
,
sparse_size
=
storage
.
sparse_size
,
rowcount
=
storage
.
rowcount
,
colptr
=
storage
.
colptr
,
colcount
=
storage
.
colcount
,
csr2csc
=
storage
.
csr2csc
,
csc2csr
=
storage
.
csc2csr
)
assert
storage
.
_rowcount
.
tolist
()
==
[
2
,
2
]
assert
storage
.
_rowptr
.
tolist
()
==
[
0
,
2
,
4
]
...
...
@@ -66,12 +68,12 @@ def test_caching(dtype, device):
assert
storage
.
_csr2csc
.
tolist
()
==
[
0
,
2
,
1
,
3
]
assert
storage
.
_csc2csr
.
tolist
()
==
[
0
,
2
,
1
,
3
]
assert
storage
.
cached_keys
()
==
[
'rowcount'
,
'
row
ptr'
,
'colcount'
,
'colptr'
,
'csr2csc'
,
'csc2csr'
'rowcount'
,
'
col
ptr'
,
'colcount'
,
'csr2csc'
,
'csc2csr'
]
storage
.
clear_cache_
()
assert
storage
.
_rowcount
is
None
assert
storage
.
_rowptr
is
None
assert
storage
.
_rowptr
is
not
None
assert
storage
.
_colcount
is
None
assert
storage
.
_colptr
is
None
assert
storage
.
_csr2csc
is
None
...
...
@@ -91,9 +93,9 @@ def test_caching(dtype, device):
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
def
test_utility
(
dtype
,
device
):
index
=
tensor
([[
0
,
0
,
1
,
1
],
[
1
,
0
,
1
,
0
]],
torch
.
long
,
device
)
row
,
col
=
tensor
([[
0
,
0
,
1
,
1
],
[
1
,
0
,
1
,
0
]],
torch
.
long
,
device
)
value
=
tensor
([
1
,
2
,
3
,
4
],
dtype
,
device
)
storage
=
SparseStorage
(
index
,
value
)
storage
=
SparseStorage
(
row
=
row
,
col
=
col
,
value
=
value
)
assert
storage
.
has_value
()
...
...
@@ -107,20 +109,20 @@ def test_utility(dtype, device):
storage
=
storage
.
set_value
(
value
,
layout
=
'coo'
)
assert
storage
.
value
.
tolist
()
==
[
1
,
2
,
3
,
4
]
storage
.
sparse_resize
_
(
3
,
3
)
assert
storage
.
sparse_size
()
==
(
3
,
3
)
storage
=
storage
.
sparse_resize
(
3
,
3
)
assert
storage
.
sparse_size
==
(
3
,
3
)
new_storage
=
copy
.
copy
(
storage
)
assert
new_storage
!=
storage
assert
new_storage
.
index
.
data_ptr
()
==
storage
.
index
.
data_ptr
()
assert
new_storage
.
col
.
data_ptr
()
==
storage
.
col
.
data_ptr
()
new_storage
=
storage
.
clone
()
assert
new_storage
!=
storage
assert
new_storage
.
index
.
data_ptr
()
!=
storage
.
index
.
data_ptr
()
assert
new_storage
.
col
.
data_ptr
()
!=
storage
.
col
.
data_ptr
()
new_storage
=
copy
.
deepcopy
(
storage
)
assert
new_storage
!=
storage
assert
new_storage
.
index
.
data_ptr
()
!=
storage
.
index
.
data_ptr
()
assert
new_storage
.
col
.
data_ptr
()
!=
storage
.
col
.
data_ptr
()
storage
.
apply_value_
(
lambda
x
:
x
+
1
)
assert
storage
.
value
.
tolist
()
==
[
2
,
3
,
4
,
5
]
...
...
@@ -128,29 +130,31 @@ def test_utility(dtype, device):
assert
storage
.
value
.
tolist
()
==
[
3
,
4
,
5
,
6
]
storage
.
apply_
(
lambda
x
:
x
.
to
(
torch
.
long
))
assert
storage
.
index
.
dtype
==
torch
.
long
assert
storage
.
col
.
dtype
==
torch
.
long
assert
storage
.
value
.
dtype
==
torch
.
long
storage
=
storage
.
apply
(
lambda
x
:
x
.
to
(
torch
.
long
))
assert
storage
.
index
.
dtype
==
torch
.
long
assert
storage
.
col
.
dtype
==
torch
.
long
assert
storage
.
value
.
dtype
==
torch
.
long
storage
.
clear_cache_
()
assert
storage
.
map
(
lambda
x
:
x
.
numel
())
==
[
8
,
4
]
assert
storage
.
map
(
lambda
x
:
x
.
numel
())
==
[
4
,
4
,
4
]
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
def
test_coalesce
(
dtype
,
device
):
index
=
tensor
([[
0
,
0
,
0
,
1
,
1
],
[
0
,
1
,
1
,
0
,
1
]],
torch
.
long
,
device
)
row
,
col
=
tensor
([[
0
,
0
,
0
,
1
,
1
],
[
0
,
1
,
1
,
0
,
1
]],
torch
.
long
,
device
)
value
=
tensor
([
1
,
1
,
1
,
3
,
4
],
dtype
,
device
)
storage
=
SparseStorage
(
index
,
value
)
storage
=
SparseStorage
(
row
=
row
,
col
=
col
,
value
=
value
)
assert
storage
.
index
.
tolist
()
==
index
.
tolist
()
assert
storage
.
row
.
tolist
()
==
row
.
tolist
()
assert
storage
.
col
.
tolist
()
==
col
.
tolist
()
assert
storage
.
value
.
tolist
()
==
value
.
tolist
()
assert
not
storage
.
is_coalesced
()
storage
=
storage
.
coalesce
()
assert
storage
.
is_coalesced
()
assert
storage
.
index
.
tolist
()
==
[[
0
,
0
,
1
,
1
],
[
0
,
1
,
0
,
1
]]
assert
storage
.
row
.
tolist
()
==
[
0
,
0
,
1
,
1
]
assert
storage
.
col
.
tolist
()
==
[
0
,
1
,
0
,
1
]
assert
storage
.
value
.
tolist
()
==
[
1
,
2
,
3
,
4
]
torch_sparse/cat.py
View file @
ac25b416
...
...
@@ -3,17 +3,16 @@ import torch
def
cat
(
tensors
,
dim
):
assert
len
(
tensors
)
>
0
has_row
=
tensors
[
0
].
storage
.
has_row
()
has_value
=
tensors
[
0
].
has_value
()
has_rowcount
=
tensors
[
0
].
storage
.
has_rowcount
()
has_rowptr
=
tensors
[
0
].
storage
.
has_rowptr
()
has_colcount
=
tensors
[
0
].
storage
.
has_colcount
()
has_colptr
=
tensors
[
0
].
storage
.
has_colptr
()
has_colcount
=
tensors
[
0
].
storage
.
has_colcount
()
has_csr2csc
=
tensors
[
0
].
storage
.
has_csr2csc
()
has_csc2csr
=
tensors
[
0
].
storage
.
has_csc2csr
()
rows
,
cols
,
values
,
sparse_size
=
[],
[],
[],
[
0
,
0
]
rowcounts
,
rowptrs
,
colcounts
,
colptrs
=
[],
[],
[],
[]
csr2cscs
,
csc2csrs
,
nnzs
=
[],
[],
0
rows
,
rowptrs
,
cols
,
values
,
sparse_size
,
nnzs
=
[],
[],
[],
[],
[
0
,
0
],
0
rowcounts
,
colcounts
,
colptrs
,
csr2cscs
,
csc2csrs
=
[],
[],
[],
[],
[]
if
isinstance
(
dim
,
int
):
dim
=
tensors
[
0
].
dim
()
+
dim
if
dim
<
0
else
dim
...
...
@@ -22,29 +21,29 @@ def cat(tensors, dim):
if
dim
==
0
:
for
tensor
in
tensors
:
row
,
col
,
value
=
tensor
.
coo
()
rows
+=
[
row
+
sparse_size
[
0
]]
rowptr
,
col
,
value
=
tensor
.
csr
()
rowptr
=
rowptr
if
len
(
rowptrs
)
==
0
else
rowptr
[
1
:]
rowptrs
+=
[
rowptr
+
nnzs
]
cols
+=
[
col
]
values
+=
[
value
]
sparse_size
[
0
]
+=
tensor
.
sparse_size
(
0
)
sparse_size
[
1
]
=
max
(
sparse_size
[
1
],
tensor
.
sparse_size
(
1
))
if
has_row
:
rows
+=
[
tensor
.
storage
.
row
+
sparse_size
[
0
]]
if
has_rowcount
:
rowcounts
+=
[
tensor
.
storage
.
rowcount
]
if
has_rowptr
:
rowptr
=
tensor
.
storage
.
rowptr
rowptr
=
rowptr
if
len
(
rowptrs
)
==
0
else
rowptr
[
1
:]
rowptrs
+=
[
rowptr
+
nnzs
]
sparse_size
[
0
]
+=
tensor
.
sparse_size
(
0
)
sparse_size
[
1
]
=
max
(
sparse_size
[
1
],
tensor
.
sparse_size
(
1
))
nnzs
+=
tensor
.
nnz
()
storage
=
tensors
[
0
].
storage
.
__class__
(
torch
.
stack
([
torch
.
cat
(
rows
),
torch
.
cat
(
cols
)],
dim
=
0
),
row
=
torch
.
cat
(
rows
)
if
has_row
else
None
,
rowptr
=
torch
.
cat
(
rowptrs
),
col
=
torch
.
cat
(
cols
),
value
=
torch
.
cat
(
values
,
dim
=
0
)
if
has_value
else
None
,
sparse_size
=
sparse_size
,
rowcount
=
torch
.
cat
(
rowcounts
)
if
has_rowcount
else
None
,
rowptr
=
torch
.
cat
(
rowptrs
)
if
has_rowptr
else
None
,
is_sorted
=
True
)
is_sorted
=
True
)
elif
dim
==
1
:
for
tensor
in
tensors
:
...
...
@@ -52,8 +51,6 @@ def cat(tensors, dim):
rows
+=
[
row
]
cols
+=
[
col
+
sparse_size
[
1
]]
values
+=
[
value
]
sparse_size
[
0
]
=
max
(
sparse_size
[
0
],
tensor
.
sparse_size
(
0
))
sparse_size
[
1
]
+=
tensor
.
sparse_size
(
1
)
if
has_colcount
:
colcounts
+=
[
tensor
.
storage
.
colcount
]
...
...
@@ -63,10 +60,13 @@ def cat(tensors, dim):
colptr
=
colptr
if
len
(
colptrs
)
==
0
else
colptr
[
1
:]
colptrs
+=
[
colptr
+
nnzs
]
sparse_size
[
0
]
=
max
(
sparse_size
[
0
],
tensor
.
sparse_size
(
0
))
sparse_size
[
1
]
+=
tensor
.
sparse_size
(
1
)
nnzs
+=
tensor
.
nnz
()
storage
=
tensors
[
0
].
storage
.
__class__
(
torch
.
stack
([
torch
.
cat
(
rows
),
torch
.
cat
(
cols
)],
dim
=
0
),
row
=
torch
.
cat
(
rows
),
col
=
torch
.
cat
(
cols
),
value
=
torch
.
cat
(
values
,
dim
=
0
)
if
has_value
else
None
,
sparse_size
=
sparse_size
,
colcount
=
torch
.
cat
(
colcounts
)
if
has_colcount
else
None
,
...
...
@@ -76,21 +76,18 @@ def cat(tensors, dim):
elif
dim
==
(
0
,
1
)
or
dim
==
(
1
,
0
):
for
tensor
in
tensors
:
row
,
col
,
value
=
tensor
.
coo
()
rows
+=
[
row
+
sparse_size
[
0
]]
rowptr
,
col
,
value
=
tensor
.
csr
()
rowptr
=
rowptr
if
len
(
rowptrs
)
==
0
else
rowptr
[
1
:]
rowptrs
+=
[
rowptr
+
nnzs
]
cols
+=
[
col
+
sparse_size
[
1
]]
values
+=
[
value
]
if
has_value
else
[]
sparse_size
[
0
]
+=
tensor
.
sparse_size
(
0
)
sparse_size
[
1
]
+=
tensor
.
sparse_size
(
1
)
values
+=
[
value
]
if
has_row
:
rows
+=
[
tensor
.
storage
.
row
+
sparse_size
[
0
]]
if
has_rowcount
:
rowcounts
+=
[
tensor
.
storage
.
rowcount
]
if
has_rowptr
:
rowptr
=
tensor
.
storage
.
rowptr
rowptr
=
rowptr
if
len
(
rowptrs
)
==
0
else
rowptr
[
1
:]
rowptrs
+=
[
rowptr
+
nnzs
]
if
has_colcount
:
colcounts
+=
[
tensor
.
storage
.
colcount
]
...
...
@@ -105,16 +102,19 @@ def cat(tensors, dim):
if
has_csc2csr
:
csc2csrs
+=
[
tensor
.
storage
.
csc2csr
+
nnzs
]
sparse_size
[
0
]
+=
tensor
.
sparse_size
(
0
)
sparse_size
[
1
]
+=
tensor
.
sparse_size
(
1
)
nnzs
+=
tensor
.
nnz
()
storage
=
tensors
[
0
].
storage
.
__class__
(
torch
.
stack
([
torch
.
cat
(
rows
),
torch
.
cat
(
cols
)],
dim
=
0
),
row
=
torch
.
cat
(
rows
)
if
has_row
else
None
,
rowptr
=
torch
.
cat
(
rowptrs
),
col
=
torch
.
cat
(
cols
),
value
=
torch
.
cat
(
values
,
dim
=
0
)
if
has_value
else
None
,
sparse_size
=
sparse_size
,
rowcount
=
torch
.
cat
(
rowcounts
)
if
has_rowcount
else
None
,
rowptr
=
torch
.
cat
(
rowptrs
)
if
has_rowptr
else
None
,
colcount
=
torch
.
cat
(
colcounts
)
if
has_colcount
else
None
,
colptr
=
torch
.
cat
(
colptrs
)
if
has_colptr
else
None
,
colcount
=
torch
.
cat
(
colcounts
)
if
has_colcount
else
None
,
csr2csc
=
torch
.
cat
(
csr2cscs
)
if
has_csr2csc
else
None
,
csc2csr
=
torch
.
cat
(
csc2csrs
)
if
has_csc2csr
else
None
,
is_sorted
=
True
,
...
...
@@ -130,7 +130,8 @@ def cat(tensors, dim):
rowptr
=
old_storage
.
_rowptr
,
col
=
old_storage
.
_col
,
value
=
torch
.
cat
(
values
,
dim
=
dim
-
1
),
sparse_size
=
old_storage
.
sparse_size
(),
sparse_size
=
old_storage
.
sparse_size
,
rowcount
=
old_storage
.
_rowcount
,
colptr
=
old_storage
.
_colptr
,
colcount
=
old_storage
.
_colcount
,
csr2csc
=
old_storage
.
_csr2csc
,
...
...
torch_sparse/diag.py
View file @
ac25b416
...
...
@@ -11,7 +11,7 @@ except ImportError:
def
remove_diag
(
src
,
k
=
0
):
row
,
col
,
value
=
src
.
coo
()
inv_mask
=
row
!=
col
if
k
==
0
else
row
!=
(
col
-
k
)
row
,
col
=
row
[
inv_mask
],
col
[
inv_mask
]
new_
row
,
new_
col
=
row
[
inv_mask
],
col
[
inv_mask
]
if
src
.
has_value
():
value
=
value
[
inv_mask
]
...
...
@@ -29,7 +29,7 @@ def remove_diag(src, k=0):
colcount
=
src
.
storage
.
colcount
.
clone
()
colcount
[
col
[
mask
]]
-=
1
storage
=
src
.
storage
.
__class__
(
row
=
row
,
col
=
col
,
value
=
value
,
storage
=
src
.
storage
.
__class__
(
row
=
new_
row
,
col
=
new_
col
,
value
=
value
,
sparse_size
=
src
.
sparse_size
(),
rowcount
=
rowcount
,
colcount
=
colcount
,
is_sorted
=
True
)
...
...
@@ -61,7 +61,7 @@ def set_diag(src, values=None, k=0):
new_value
=
None
if
src
.
has_value
():
new_value
=
torch
.
new_empty
((
mask
.
size
(
0
),
)
+
value
.
size
()[
1
:])
new_value
=
value
.
new_empty
((
mask
.
size
(
0
),
)
+
value
.
size
()[
1
:])
new_value
[
mask
]
=
value
new_value
[
inv_mask
]
=
values
if
values
is
not
None
else
1
...
...
torch_sparse/storage.py
View file @
ac25b416
...
...
@@ -154,7 +154,7 @@ class SparseStorage(object):
idx
=
self
.
col
.
new_zeros
(
col
.
numel
()
+
1
)
idx
[
1
:]
=
sparse_size
[
1
]
*
self
.
row
+
self
.
col
if
(
idx
[
1
:]
<
idx
[:
-
1
]).
any
():
perm
=
idx
.
argsort
()
perm
=
idx
[
1
:]
.
argsort
()
self
.
_row
=
self
.
row
[
perm
]
self
.
_col
=
self
.
col
[
perm
]
self
.
_value
=
self
.
value
[
perm
]
if
self
.
has_value
()
else
None
...
...
@@ -313,12 +313,12 @@ class SparseStorage(object):
return
self
.
csr2csc
.
argsort
()
def
is_coalesced
(
self
):
idx
=
self
.
col
.
new_
zeros
(
self
.
col
.
numel
()
+
1
)
idx
=
self
.
col
.
new_
full
(
(
self
.
col
.
numel
()
+
1
,
),
-
1
)
idx
[
1
:]
=
self
.
sparse_size
[
1
]
*
self
.
row
+
self
.
col
return
(
idx
[
1
:]
>
idx
[:
-
1
]).
all
().
item
()
def
coalesce
(
self
,
reduce
=
'add'
):
idx
=
self
.
col
.
new_
zeros
(
self
.
col
.
numel
()
+
1
)
idx
=
self
.
col
.
new_
full
(
(
self
.
col
.
numel
()
+
1
,
),
-
1
)
idx
[
1
:]
=
self
.
sparse_size
[
1
]
*
self
.
row
+
self
.
col
mask
=
idx
[
1
:]
>
idx
[:
-
1
]
...
...
@@ -330,8 +330,9 @@ class SparseStorage(object):
value
=
self
.
value
if
self
.
has_value
():
idx
=
mask
.
cumsum
(
0
).
sub_
(
1
)
value
=
segment_csr
(
idx
,
value
,
reduce
=
reduce
)
ptr
=
mask
.
nonzero
().
flatten
()
ptr
=
torch
.
cat
([
ptr
,
ptr
.
new_full
((
1
,
),
value
.
size
(
0
))])
value
=
segment_csr
(
value
,
ptr
,
reduce
=
reduce
)
value
=
value
[
0
]
if
isinstance
(
value
,
tuple
)
else
value
return
self
.
__class__
(
row
=
row
,
col
=
col
,
value
=
value
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment