OpenDAS / torch-sparse / Commits / 925f9567

Commit 925f9567, authored Feb 03, 2020 by rusty1s
parent 26aee002

cat and dim fix

Showing 10 changed files with 251 additions and 320 deletions (+251, -320)
test/test_add.py           +0    -65
test/test_cat.py           +8    -10
test/test_index_select.py  +0    -27
test/test_rowptr.py        +0    -42
test/test_storage.py       +1    -12
torch_sparse/__init__.py   +1    -15
torch_sparse/cat.py        +224  -129
torch_sparse/storage.py    +16   -2
torch_sparse/tensor.py     +1    -1
torch_sparse/unique.py     +0    -17
test/test_add.py (deleted, file mode 100644 → 0)

import time
from itertools import product

from scipy.io import loadmat
import numpy as np
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.add import sparse_add

from .utils import dtypes, devices, tensor

devices = ['cpu']
dtypes = [torch.float]


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_sparse_add(dtype, device):
    name = ('DIMACS10', 'citationCiteseer')[1]
    mat_scipy = loadmat(f'benchmark/{name}.mat')['Problem'][0][0][2].tocsr()
    mat = SparseTensor.from_scipy(mat_scipy)
    mat1 = mat[:, 0:100000]
    mat2 = mat[:, 100000:200000]
    # print(mat1.shape)
    # print(mat2.shape)

    # 0.0159 to beat
    t = time.perf_counter()
    mat = sparse_add(mat1, mat2)
    # print(time.perf_counter() - t)
    # print(mat.nnz())

    mat1 = mat_scipy[:, 0:100000]
    mat2 = mat_scipy[:, 100000:200000]
    t = time.perf_counter()
    mat = mat1 + mat2
    # print(time.perf_counter() - t)
    # print(mat.nnz)

    # mat1 + mat2
    # mat1 = mat1.tocoo()
    # mat2 = mat2.tocoo()
    # row1, col1 = mat1.row, mat1.col
    # row2, col2 = mat2.row, mat2.col
    # idx1 = row1 * 100000 + col1
    # idx2 = row2 * 100000 + col2
    # t = time.perf_counter()
    # np.union1d(idx1, idx2)
    # print(time.perf_counter() - t)

    # index = tensor([[0, 0, 1], [0, 1, 2]], torch.long, device)
    # mat1 = SparseTensor(index)
    # print()
    # print(mat1.to_dense())

    # index = tensor([[0, 0, 1, 2], [0, 1, 1, 0]], torch.long, device)
    # mat2 = SparseTensor(index)
    # print(mat2.to_dense())

    # add(mat1, mat2)
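For reference, the scipy baseline this deleted benchmark raced against behaves as follows; a minimal illustration (not part of the commit), using only numpy/scipy:

import numpy as np
from scipy.sparse import csr_matrix

# Adding two CSR matrices unions their coordinate sets and sums overlapping
# entries, which is the operation sparse_add was being compared against.
a = csr_matrix(np.array([[1, 0], [0, 2]]))
b = csr_matrix(np.array([[0, 3], [0, 4]]))
assert (a + b).toarray().tolist() == [[1, 3], [0, 6]]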
test/test_cat.py

 import pytest
 import torch
 from torch_sparse.tensor import SparseTensor
-from torch_sparse.cat import cat
+from torch_sparse.cat import cat, cat_diag

 from .utils import devices, tensor

@@ -21,29 +21,27 @@ def test_cat(device):
                                        [0, 1, 0], [1, 0, 0]]
     assert out.storage.has_row()
     assert out.storage.has_rowptr()
-    assert len(out.storage.cached_keys()) == 1
     assert out.storage.has_rowcount()
+    assert out.storage.num_cached_keys() == 1

     out = cat([mat1, mat2], dim=1)
     assert out.to_dense().tolist() == [[1, 1, 0, 1, 1], [0, 0, 1, 0, 1],
                                        [0, 0, 0, 1, 0]]
     assert out.storage.has_row()
     assert not out.storage.has_rowptr()
-    assert len(out.storage.cached_keys()) == 2
-    assert out.storage.has_colcount()
-    assert out.storage.has_colptr()
+    assert out.storage.num_cached_keys() == 2

-    out = cat([mat1, mat2], dim=(0, 1))
+    out = cat_diag([mat1, mat2])
     assert out.to_dense().tolist() == [[1, 1, 0, 0, 0], [0, 0, 1, 0, 0],
                                        [0, 0, 0, 1, 1], [0, 0, 0, 0, 1],
                                        [0, 0, 0, 1, 0]]
     assert out.storage.has_row()
     assert out.storage.has_rowptr()
-    assert len(out.storage.cached_keys()) == 5
+    assert out.storage.num_cached_keys() == 5

-    mat1.set_value_(torch.randn((mat1.nnz(), 4), device=device))
+    mat1 = mat1.set_value_(torch.randn((mat1.nnz(), 4), device=device))
     out = cat([mat1, mat1], dim=-1)
-    assert out.storage.value.size() == (mat1.nnz(), 8)
+    assert out.storage.value().size() == (mat1.nnz(), 8)
     assert out.storage.has_row()
     assert out.storage.has_rowptr()
-    assert len(out.storage.cached_keys()) == 5
+    assert out.storage.num_cached_keys() == 5
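The expected to_dense() values asserted above have straightforward dense equivalents; a sketch (illustration only, not code from this commit) using plain torch ops, where A and B stand in for the dense forms of mat1 and mat2 implied by the asserts:

import torch
import torch.nn.functional as F

A = torch.tensor([[1., 1., 0.],
                  [0., 0., 1.]])                 # mat1: 2 x 3
B = torch.tensor([[1., 1.],
                  [0., 1.],
                  [1., 0.]])                     # mat2: 3 x 2

# dim=0: stack rows, padding the narrower input with zero columns (5 x 3)
dim0 = torch.cat([A, F.pad(B, (0, 1))], dim=0)

# dim=1: stack columns, padding the shorter input with zero rows (3 x 5)
dim1 = torch.cat([F.pad(A, (0, 0, 0, 1)), B], dim=1)

# cat_diag: block-diagonal stacking (5 x 5)
diag = torch.block_diag(A, B)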
test/test_index_select.py (deleted, file mode 100644 → 0)

import time
from itertools import product

from scipy.io import loadmat
import numpy as np
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.add import sparse_add

from .utils import dtypes, devices, tensor

devices = ['cpu']
dtypes = [torch.float]


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_index_select(dtype, device):
    row = torch.tensor([0, 0, 1, 1, 2])
    col = torch.tensor([0, 1, 1, 2, 1])
    mat = SparseTensor(row=row, col=col)
    print()
    print(mat.to_dense())
    pass
    mat = mat.index_select(0, torch.tensor([0, 2]))
    print(mat.to_dense())
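The dense analogue of the index_select call exercised by this deleted test (illustration only, not part of the commit):

import torch

dense = torch.tensor([[1, 1, 0],
                      [0, 1, 1],
                      [0, 1, 0]])                       # mat.to_dense() above
picked = dense.index_select(0, torch.tensor([0, 2]))    # keep rows 0 and 2
assert picked.tolist() == [[1, 1, 0], [0, 1, 0]]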
test/test_rowptr.py (deleted, file mode 100644 → 0)

from itertools import product

import pytest
import torch
from torch_sparse import rowptr_cpu

from .utils import tensor, devices

if torch.cuda.is_available():
    from torch_sparse import rowptr_cuda

tests = [
    {
        'row': [0, 0, 1, 1, 1, 2, 2],
        'size': 5,
        'rowptr': [0, 2, 5, 7, 7, 7],
    },
    {
        'row': [0, 0, 1, 1, 1, 4, 4],
        'size': 5,
        'rowptr': [0, 2, 5, 5, 5, 7],
    },
    {
        'row': [2, 2, 4, 4],
        'size': 7,
        'rowptr': [0, 0, 0, 2, 2, 4, 4, 4],
    },
]


def rowptr(row, size):
    return (rowptr_cuda if row.is_cuda else rowptr_cpu).rowptr(row, size)


@pytest.mark.parametrize('test,device', product(tests, devices))
def test_rowptr(test, device):
    row = tensor(test['row'], torch.long, device)
    size = test['size']
    expected = tensor(test['rowptr'], torch.long, device)

    out = rowptr(row, size)
    assert torch.all(out == expected)
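The rowptr_cpu/rowptr_cuda kernels exercised by this deleted test compute the CSR index pointer of a sorted row vector; a pure-PyTorch reference (a sketch for illustration, not the removed C++/CUDA code) that reproduces the expected values above:

import torch

def rowptr_reference(row: torch.Tensor, size: int) -> torch.Tensor:
    counts = torch.bincount(row, minlength=size)  # nonzeros per row
    out = torch.zeros(size + 1, dtype=torch.long)
    out[1:] = torch.cumsum(counts, dim=0)         # prefix sums give the pointers
    return out

assert rowptr_reference(torch.tensor([0, 0, 1, 1, 1, 2, 2]), 5).tolist() == \
    [0, 2, 5, 7, 7, 7]
assert rowptr_reference(torch.tensor([2, 2, 4, 4]), 7).tolist() == \
    [0, 0, 0, 2, 2, 4, 4, 4]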
test/test_storage.py

@@ -3,7 +3,7 @@ from itertools import product
 import pytest
 import torch
-from torch_sparse.storage import SparseStorage, no_cache
+from torch_sparse.storage import SparseStorage

 from .utils import dtypes, devices, tensor

@@ -79,17 +79,6 @@ def test_caching(dtype, device):
     assert storage._csr2csc is None
     assert storage.cached_keys() == []

-    with no_cache():
-        storage.fill_cache_()
-    assert storage.cached_keys() == []
-
-    @no_cache()
-    def do_something(storage):
-        return storage.fill_cache_()
-
-    storage = do_something(storage)
-    assert storage.cached_keys() == []
-

 @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
 def test_utility(dtype, device):
torch_sparse/__init__.py

-import torch
-
-torch.ops.load_library('torch_sparse/convert_cpu.so')
-torch.ops.load_library('torch_sparse/diag_cpu.so')
-torch.ops.load_library('torch_sparse/spmm_cpu.so')
-
-try:
-    torch.ops.load_library('torch_sparse/convert_cuda.so')
-    torch.ops.load_library('torch_sparse/diag_cuda.so')
-    torch.ops.load_library('torch_sparse/spmm_cuda.so')
-    torch.ops.load_library('torch_sparse/spspmm_cuda.so')
-except OSError as e:
-    if torch.cuda.is_available():
-        raise e
-
 from .convert import to_torch_sparse, from_torch_sparse, to_scipy, from_scipy
 from .coalesce import coalesce
 from .transpose import transpose

@@ -48,3 +33,4 @@ from .mul import mul, mul_, mul_nnz, mul_nnz_
 from .reduce import sum, mean, min, max
 from .matmul import (spmm_sum, spmm_add, spmm_mean, spmm_min, spmm_max, spmm,
                      spspmm_sum, spspmm_add, spspmm, matmul)
+from .cat import cat, cat_diag
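With the export added above, both concatenation helpers become importable from the package root (usage sketch, not part of the diff):

from torch_sparse import cat, cat_diag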
torch_sparse/cat.py (rewritten)

Old version (removed):

import torch


def cat(tensors, dim):
    assert len(tensors) > 0

    has_row = tensors[0].storage.has_row()
    has_value = tensors[0].has_value()
    has_rowcount = tensors[0].storage.has_rowcount()
    has_colptr = tensors[0].storage.has_colptr()
    has_colcount = tensors[0].storage.has_colcount()
    has_csr2csc = tensors[0].storage.has_csr2csc()
    has_csc2csr = tensors[0].storage.has_csc2csr()

    rows, rowptrs, cols, values, sparse_size, nnzs = [], [], [], [], [0, 0], 0
    rowcounts, colcounts, colptrs, csr2cscs, csc2csrs = [], [], [], [], []

    if isinstance(dim, int):
        dim = tensors[0].dim() + dim if dim < 0 else dim
    else:
        dim = tuple([tensors[0].dim() + d if d < 0 else d for d in dim])

    if dim == 0:
        for tensor in tensors:
            rowptr, col, value = tensor.csr()
            rowptr = rowptr if len(rowptrs) == 0 else rowptr[1:]
            rowptrs += [rowptr + nnzs]
            cols += [col]
            values += [value]
            if has_row:
                rows += [tensor.storage.row + sparse_size[0]]
            if has_rowcount:
                rowcounts += [tensor.storage.rowcount]
            sparse_size[0] += tensor.sparse_size(0)
            sparse_size[1] = max(sparse_size[1], tensor.sparse_size(1))
            nnzs += tensor.nnz()

        storage = tensors[0].storage.__class__(
            row=torch.cat(rows) if has_row else None,
            rowptr=torch.cat(rowptrs),
            col=torch.cat(cols),
            value=torch.cat(values, dim=0) if has_value else None,
            sparse_size=sparse_size,
            rowcount=torch.cat(rowcounts) if has_rowcount else None,
            is_sorted=True)

    elif dim == 1:
        for tensor in tensors:
            row, col, value = tensor.coo()
            rows += [row]
            cols += [col + sparse_size[1]]
            values += [value]
            if has_colcount:
                colcounts += [tensor.storage.colcount]
            if has_colptr:
                colptr = tensor.storage.colptr
                colptr = colptr if len(colptrs) == 0 else colptr[1:]
                colptrs += [colptr + nnzs]
            sparse_size[0] = max(sparse_size[0], tensor.sparse_size(0))
            sparse_size[1] += tensor.sparse_size(1)
            nnzs += tensor.nnz()

        storage = tensors[0].storage.__class__(
            row=torch.cat(rows),
            col=torch.cat(cols),
            value=torch.cat(values, dim=0) if has_value else None,
            sparse_size=sparse_size,
            colcount=torch.cat(colcounts) if has_colcount else None,
            colptr=torch.cat(colptrs) if has_colptr else None,
            is_sorted=False,
        )

    elif dim == (0, 1) or dim == (1, 0):
        for tensor in tensors:
            rowptr, col, value = tensor.csr()
            rowptr = rowptr if len(rowptrs) == 0 else rowptr[1:]
            rowptrs += [rowptr + nnzs]
            cols += [col + sparse_size[1]]
            values += [value]
            if has_row:
                rows += [tensor.storage.row + sparse_size[0]]
            if has_rowcount:
                rowcounts += [tensor.storage.rowcount]
            if has_colcount:
                colcounts += [tensor.storage.colcount]
            if has_colptr:
                colptr = tensor.storage.colptr
                colptr = colptr if len(colptrs) == 0 else colptr[1:]
                colptrs += [colptr + nnzs]
            if has_csr2csc:
                csr2cscs += [tensor.storage.csr2csc + nnzs]
            if has_csc2csr:
                csc2csrs += [tensor.storage.csc2csr + nnzs]
            sparse_size[0] += tensor.sparse_size(0)
            sparse_size[1] += tensor.sparse_size(1)
            nnzs += tensor.nnz()

        storage = tensors[0].storage.__class__(
            row=torch.cat(rows) if has_row else None,
            rowptr=torch.cat(rowptrs),
            col=torch.cat(cols),
            value=torch.cat(values, dim=0) if has_value else None,
            sparse_size=sparse_size,
            rowcount=torch.cat(rowcounts) if has_rowcount else None,
            colptr=torch.cat(colptrs) if has_colptr else None,
            colcount=torch.cat(colcounts) if has_colcount else None,
            csr2csc=torch.cat(csr2cscs) if has_csr2csc else None,
            csc2csr=torch.cat(csc2csrs) if has_csc2csr else None,
            is_sorted=True,
        )

    elif isinstance(dim, int) and dim > 1 and dim < tensors[0].dim():
        for tensor in tensors:
            values += [tensor.storage.value]
        old_storage = tensors[0].storage
        storage = old_storage.__class__(
            row=old_storage._row,
            rowptr=old_storage._rowptr,
            col=old_storage._col,
            value=torch.cat(values, dim=dim - 1),
            sparse_size=old_storage.sparse_size,
            rowcount=old_storage._rowcount,
            colptr=old_storage._colptr,
            colcount=old_storage._colcount,
            csr2csc=old_storage._csr2csc,
            csc2csr=old_storage._csc2csr,
            is_sorted=True,
        )

    else:
        raise IndexError(
            (f'Dimension out of range: Expected to be in range of '
             f'[{-tensors[0].dim()}, {tensors[0].dim() - 1}, but got {dim}]'))

    return tensors[0].__class__.from_storage(storage)

New version (added):

from typing import List, Optional

import torch
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor


@torch.jit.script
def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
    assert len(tensors) > 0
    if dim < 0:
        dim = tensors[0].dim() + dim

    if dim == 0:
        rows: List[torch.Tensor] = []
        rowptrs: List[torch.Tensor] = []
        cols: List[torch.Tensor] = []
        values: List[torch.Tensor] = []
        sparse_sizes: List[int] = [0, 0]
        rowcounts: List[torch.Tensor] = []
        nnz: int = 0
        for tensor in tensors:
            row = tensor.storage._row
            if row is not None:
                rows.append(row + sparse_sizes[0])

            rowptr = tensor.storage._rowptr
            if rowptr is not None:
                if len(rowptrs) > 0:
                    rowptr = rowptr[1:]
                rowptrs.append(rowptr + nnz)

            cols.append(tensor.storage._col)

            value = tensor.storage._value
            if value is not None:
                values.append(value)

            rowcount = tensor.storage._rowcount
            if rowcount is not None:
                rowcounts.append(rowcount)

            sparse_sizes[0] += tensor.sparse_size(0)
            sparse_sizes[1] = max(sparse_sizes[1], tensor.sparse_size(1))
            nnz += tensor.nnz()

        row: Optional[torch.Tensor] = None
        if len(rows) == len(tensors):
            row = torch.cat(rows, dim=0)

        rowptr: Optional[torch.Tensor] = None
        if len(rowptrs) == len(tensors):
            rowptr = torch.cat(rowptrs, dim=0)

        col = torch.cat(cols, dim=0)

        value: Optional[torch.Tensor] = None
        if len(values) == len(tensors):
            value = torch.cat(values, dim=0)

        rowcount: Optional[torch.Tensor] = None
        if len(rowcounts) == len(tensors):
            rowcount = torch.cat(rowcounts, dim=0)

        storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=rowcount,
                                colptr=None, colcount=None, csr2csc=None,
                                csc2csr=None, is_sorted=True)
        return tensors[0].from_storage(storage)

    elif dim == 1:
        rows: List[torch.Tensor] = []
        cols: List[torch.Tensor] = []
        values: List[torch.Tensor] = []
        sparse_sizes: List[int] = [0, 0]
        colptrs: List[torch.Tensor] = []
        colcounts: List[torch.Tensor] = []
        nnz: int = 0
        for tensor in tensors:
            row, col, value = tensor.coo()
            rows.append(row)
            cols.append(tensor.storage._col + sparse_sizes[1])
            if value is not None:
                values.append(value)

            colptr = tensor.storage._colptr
            if colptr is not None:
                if len(colptrs) > 0:
                    colptr = colptr[1:]
                colptrs.append(colptr + nnz)

            colcount = tensor.storage._colcount
            if colcount is not None:
                colcounts.append(colcount)

            sparse_sizes[0] = max(sparse_sizes[0], tensor.sparse_size(0))
            sparse_sizes[1] += tensor.sparse_size(1)
            nnz += tensor.nnz()

        row = torch.cat(rows, dim=0)
        col = torch.cat(cols, dim=0)

        value: Optional[torch.Tensor] = None
        if len(values) == len(tensors):
            value = torch.cat(values, dim=0)

        colptr: Optional[torch.Tensor] = None
        if len(colptrs) == len(tensors):
            colptr = torch.cat(colptrs, dim=0)

        colcount: Optional[torch.Tensor] = None
        if len(colcounts) == len(tensors):
            colcount = torch.cat(colcounts, dim=0)

        storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=None,
                                colptr=colptr, colcount=colcount, csr2csc=None,
                                csc2csr=None, is_sorted=False)
        return tensors[0].from_storage(storage)

    elif dim > 1 and dim < tensors[0].dim():
        values: List[torch.Tensor] = []
        for tensor in tensors:
            value = tensor.storage.value()
            if value is not None:
                values.append(value)

        value: Optional[torch.Tensor] = None
        if len(values) == len(tensors):
            value = torch.cat(values, dim=dim - 1)

        return tensors[0].set_value(value, layout='coo')

    else:
        raise IndexError(
            (f'Dimension out of range: Expected to be in range of '
             f'[{-tensors[0].dim()}, {tensors[0].dim() - 1}, but got {dim}]'))


@torch.jit.script
def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
    assert len(tensors) > 0

    rows: List[torch.Tensor] = []
    rowptrs: List[torch.Tensor] = []
    cols: List[torch.Tensor] = []
    values: List[torch.Tensor] = []
    sparse_sizes: List[int] = [0, 0]
    rowcounts: List[torch.Tensor] = []
    colptrs: List[torch.Tensor] = []
    colcounts: List[torch.Tensor] = []
    csr2cscs: List[torch.Tensor] = []
    csc2csrs: List[torch.Tensor] = []
    nnz: int = 0
    for tensor in tensors:
        row = tensor.storage._row
        if row is not None:
            rows.append(row + sparse_sizes[0])

        rowptr = tensor.storage._rowptr
        if rowptr is not None:
            if len(rowptrs) > 0:
                rowptr = rowptr[1:]
            rowptrs.append(rowptr + nnz)

        cols.append(tensor.storage._col + sparse_sizes[1])

        value = tensor.storage._value
        if value is not None:
            values.append(value)

        rowcount = tensor.storage._rowcount
        if rowcount is not None:
            rowcounts.append(rowcount)

        colptr = tensor.storage._colptr
        if colptr is not None:
            if len(colptrs) > 0:
                colptr = colptr[1:]
            colptrs.append(colptr + nnz)

        colcount = tensor.storage._colcount
        if colcount is not None:
            colcounts.append(colcount)

        csr2csc = tensor.storage._csr2csc
        if csr2csc is not None:
            csr2cscs.append(csr2csc + nnz)

        csc2csr = tensor.storage._csc2csr
        if csc2csr is not None:
            csc2csrs.append(csc2csr + nnz)

        sparse_sizes[0] += tensor.sparse_size(0)
        sparse_sizes[1] += tensor.sparse_size(1)
        nnz += tensor.nnz()

    row: Optional[torch.Tensor] = None
    if len(rows) == len(tensors):
        row = torch.cat(rows, dim=0)

    rowptr: Optional[torch.Tensor] = None
    if len(rowptrs) == len(tensors):
        rowptr = torch.cat(rowptrs, dim=0)

    col = torch.cat(cols, dim=0)

    value: Optional[torch.Tensor] = None
    if len(values) == len(tensors):
        value = torch.cat(values, dim=0)

    rowcount: Optional[torch.Tensor] = None
    if len(rowcounts) == len(tensors):
        rowcount = torch.cat(rowcounts, dim=0)

    colptr: Optional[torch.Tensor] = None
    if len(colptrs) == len(tensors):
        colptr = torch.cat(colptrs, dim=0)

    colcount: Optional[torch.Tensor] = None
    if len(colcounts) == len(tensors):
        colcount = torch.cat(colcounts, dim=0)

    csr2csc: Optional[torch.Tensor] = None
    if len(csr2cscs) == len(tensors):
        csr2csc = torch.cat(csr2cscs, dim=0)

    csc2csr: Optional[torch.Tensor] = None
    if len(csc2csrs) == len(tensors):
        csc2csr = torch.cat(csc2csrs, dim=0)

    storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                            sparse_sizes=sparse_sizes, rowcount=rowcount,
                            colptr=colptr, colcount=colcount, csr2csc=csr2csc,
                            csc2csr=csc2csr, is_sorted=True)
    return tensors[0].from_storage(storage)
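A minimal usage sketch of the rewritten module, assuming the SparseTensor construction used in the tests above (illustration only, not code from this commit; the index pairs are taken from the deleted test_add.py comments):

import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.cat import cat, cat_diag

mat1 = SparseTensor(row=torch.tensor([0, 0, 1]),
                    col=torch.tensor([0, 1, 2]))        # 2 x 3
mat2 = SparseTensor(row=torch.tensor([0, 0, 1, 2]),
                    col=torch.tensor([0, 1, 1, 0]))     # 3 x 2

out0 = cat([mat1, mat2], dim=0)   # 5 x 3, keeps rowptr/rowcount when all inputs have them
out1 = cat([mat1, mat2], dim=1)   # 3 x 5, keeps colptr/colcount, result is no longer row-sorted
outd = cat_diag([mat1, mat2])     # 5 x 5 block-diagonal, keeps every cached field it can

Note that the old dim=(0, 1) spelling from the previous version is gone; block-diagonal concatenation gets its own entry point, presumably because the scripted cat now types dim as a plain int.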
torch_sparse/storage.py

@@ -199,7 +199,7 @@ class SparseStorage(object):
     def set_value_(self, value: Optional[torch.Tensor],
                    layout: Optional[str] = None):
         if value is not None:
-            if get_layout(layout) == 'csc2csr':
+            if get_layout(layout) == 'csc':
                 value = value[self.csc2csr()]
             value = value.contiguous()
             assert value.device == self._col.device

@@ -211,7 +211,7 @@ class SparseStorage(object):
     def set_value(self, value: Optional[torch.Tensor],
                   layout: Optional[str] = None):
         if value is not None:
-            if get_layout(layout) == 'csc2csr':
+            if get_layout(layout) == 'csc':
                 value = value[self.csc2csr()]
             value = value.contiguous()
             assert value.device == self._col.device

@@ -384,6 +384,20 @@ class SparseStorage(object):
         self._csc2csr = None
         return self

+    def num_cached_keys(self) -> int:
+        count = 0
+        if self.has_rowcount():
+            count += 1
+        if self.has_colptr():
+            count += 1
+        if self.has_colcount():
+            count += 1
+        if self.has_csr2csc():
+            count += 1
+        if self.has_csc2csr():
+            count += 1
+        return count
+
     def copy(self):
         return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
                              value=self._value,
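A small sketch of what the new num_cached_keys() reports, assuming the SparseTensor construction and the storage.fill_cache_() call exercised by the tests above (illustration only):

import torch
from torch_sparse.tensor import SparseTensor

mat = SparseTensor(row=torch.tensor([0, 0, 1, 1, 2]),
                   col=torch.tensor([0, 1, 1, 2, 1]))
assert mat.storage.num_cached_keys() == 0   # none of the five derived fields exist yet
mat.storage.fill_cache_()                   # builds rowcount, colptr, colcount, csr2csc, csc2csr
assert mat.storage.num_cached_keys() == 5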
torch_sparse/tensor.py

@@ -197,7 +197,7 @@ class SparseTensor(object):
         sizes = self.sparse_sizes()
         value = self.storage.value()
         if value is not None:
-            sizes += value.size()[1:]
+            sizes = sizes + value.size()[1:]
         return sizes

     def size(self, dim: int) -> int:
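A plausible reading of this one-line change (the commit message does not spell it out): if sparse_sizes() hands back a list, += extends that list in place, while building a new object leaves it untouched. A plain-Python illustration:

sizes = [3, 4]             # stands in for self.sparse_sizes()
tail = [16]                # stands in for list(value.size()[1:])

grown = sizes + tail       # new object: [3, 4, 16]
assert sizes == [3, 4]     # the original list is unchanged

sizes += tail              # in-place extend: the original now reads [3, 4, 16]
assert sizes == [3, 4, 16]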
torch_sparse/unique.py (deleted, file mode 100644 → 0)

import torch
import numpy as np

if torch.cuda.is_available():
    import torch_sparse.unique_cuda


def unique(src):
    src = src.contiguous().view(-1)

    if src.is_cuda:
        out, perm = torch_sparse.unique_cuda.unique(src)
    else:
        out, perm = np.unique(src.numpy(), return_index=True)
        out, perm = torch.from_numpy(out), torch.from_numpy(perm)

    return out, perm
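The deleted helper returned each unique value together with the index of its first occurrence (numpy's return_index=True). A pure-PyTorch stand-in on recent PyTorch versions (a sketch, not the removed code; it assumes Tensor.scatter_reduce_ is available):

import torch

def unique_with_first_index(src: torch.Tensor):
    src = src.contiguous().view(-1)
    out, inverse = torch.unique(src, sorted=True, return_inverse=True)
    # For every unique value keep the smallest source position that produced it.
    perm = torch.full((out.numel(),), src.numel(), dtype=torch.long)
    perm.scatter_reduce_(0, inverse, torch.arange(src.numel()), reduce='amin')
    return out, perm

out, perm = unique_with_first_index(torch.tensor([3, 1, 3, 2, 1]))
assert out.tolist() == [1, 2, 3] and perm.tolist() == [1, 3, 0]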