Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
872938af
Commit
872938af
authored
Jul 01, 2020
by
rusty1s
Browse files
overload for cat
parent
468aea5b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
140 additions
and
120 deletions
+140
-120
test/test_cat.py
test/test_cat.py
+2
-2
torch_sparse/__init__.py
torch_sparse/__init__.py
+1
-2
torch_sparse/cat.py
torch_sparse/cat.py
+137
-116
No files found.
test/test_cat.py
View file @
872938af
import
pytest
import
pytest
import
torch
import
torch
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.cat
import
cat
,
cat_diag
from
torch_sparse.cat
import
cat
from
.utils
import
devices
,
tensor
from
.utils
import
devices
,
tensor
...
@@ -31,7 +31,7 @@ def test_cat(device):
...
@@ -31,7 +31,7 @@ def test_cat(device):
assert
not
out
.
storage
.
has_rowptr
()
assert
not
out
.
storage
.
has_rowptr
()
assert
out
.
storage
.
num_cached_keys
()
==
2
assert
out
.
storage
.
num_cached_keys
()
==
2
out
=
cat
_diag
([
mat1
,
mat2
])
out
=
cat
([
mat1
,
mat2
]
,
dim
=
(
0
,
1
)
)
assert
out
.
to_dense
().
tolist
()
==
[[
1
,
1
,
0
,
0
,
0
],
[
0
,
0
,
1
,
0
,
0
],
assert
out
.
to_dense
().
tolist
()
==
[[
1
,
1
,
0
,
0
,
0
],
[
0
,
0
,
1
,
0
,
0
],
[
0
,
0
,
0
,
1
,
1
],
[
0
,
0
,
0
,
0
,
1
],
[
0
,
0
,
0
,
1
,
1
],
[
0
,
0
,
0
,
0
,
1
],
[
0
,
0
,
0
,
1
,
0
]]
[
0
,
0
,
0
,
1
,
0
]]
...
...
torch_sparse/__init__.py
View file @
872938af
...
@@ -44,7 +44,7 @@ from .add import add, add_, add_nnz, add_nnz_ # noqa
...
@@ -44,7 +44,7 @@ from .add import add, add_, add_nnz, add_nnz_ # noqa
from
.mul
import
mul
,
mul_
,
mul_nnz
,
mul_nnz_
# noqa
from
.mul
import
mul
,
mul_
,
mul_nnz
,
mul_nnz_
# noqa
from
.reduce
import
sum
,
mean
,
min
,
max
# noqa
from
.reduce
import
sum
,
mean
,
min
,
max
# noqa
from
.matmul
import
matmul
# noqa
from
.matmul
import
matmul
# noqa
from
.cat
import
cat
,
cat_diag
# noqa
from
.cat
import
cat
# noqa
from
.rw
import
random_walk
# noqa
from
.rw
import
random_walk
# noqa
from
.metis
import
partition
# noqa
from
.metis
import
partition
# noqa
from
.bandwidth
import
reverse_cuthill_mckee
# noqa
from
.bandwidth
import
reverse_cuthill_mckee
# noqa
...
@@ -89,7 +89,6 @@ __all__ = [
...
@@ -89,7 +89,6 @@ __all__ = [
'max'
,
'max'
,
'matmul'
,
'matmul'
,
'cat'
,
'cat'
,
'cat_diag'
,
'random_walk'
,
'random_walk'
,
'partition'
,
'partition'
,
'reverse_cuthill_mckee'
,
'reverse_cuthill_mckee'
,
...
...
torch_sparse/cat.py
View file @
872938af
from
typing
import
Optional
,
List
from
typing
import
Optional
,
List
,
Tuple
import
torch
import
torch
from
torch_sparse.storage
import
SparseStorage
from
torch_sparse.storage
import
SparseStorage
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
def
cat
(
tensors
:
List
[
SparseTensor
],
dim
:
int
)
->
SparseTensor
:
@
torch
.
jit
.
_overload
# noqa: F811
def
cat
(
tensors
,
dim
):
# noqa: F811
# type: (List[SparseTensor], int) -> SparseTensor
pass
@
torch
.
jit
.
_overload
# noqa: F811
def
cat
(
tensors
,
dim
):
# noqa: F811
# type: (List[SparseTensor], Tuple[int, int]) -> SparseTensor
pass
@
torch
.
jit
.
_overload
# noqa: F811
def
cat
(
tensors
,
dim
):
# noqa: F811
# type: (List[SparseTensor], List[int]) -> SparseTensor
pass
def
cat
(
tensors
,
dim
):
# noqa: F811
assert
len
(
tensors
)
>
0
assert
len
(
tensors
)
>
0
if
dim
<
0
:
dim
=
tensors
[
0
].
dim
()
+
dim
if
isinstance
(
dim
,
int
):
dim
=
tensors
[
0
].
dim
()
+
dim
if
dim
<
0
else
dim
if
dim
==
0
:
rows
:
List
[
torch
.
Tensor
]
=
[]
if
dim
==
0
:
rowptrs
:
List
[
torch
.
Tensor
]
=
[]
return
cat_first
(
tensors
)
cols
:
List
[
torch
.
Tensor
]
=
[]
values
:
List
[
torch
.
Tensor
]
=
[]
elif
dim
==
1
:
sparse_sizes
:
List
[
int
]
=
[
0
,
0
]
return
cat_second
(
tensors
)
rowcounts
:
List
[
torch
.
Tensor
]
=
[]
pass
nnz
:
int
=
0
elif
dim
>
1
and
dim
<
tensors
[
0
].
dim
():
for
tensor
in
tensors
:
values
=
[]
row
=
tensor
.
storage
.
_row
for
tensor
in
tensors
:
if
row
is
not
None
:
value
=
tensor
.
storage
.
value
()
rows
.
append
(
row
+
sparse_sizes
[
0
])
assert
value
is
not
None
rowptr
=
tensor
.
storage
.
_rowptr
if
rowptr
is
not
None
:
if
len
(
rowptrs
)
>
0
:
rowptr
=
rowptr
[
1
:]
rowptrs
.
append
(
rowptr
+
nnz
)
cols
.
append
(
tensor
.
storage
.
_col
)
value
=
tensor
.
storage
.
_value
if
value
is
not
None
:
values
.
append
(
value
)
values
.
append
(
value
)
value
=
torch
.
cat
(
values
,
dim
=
dim
-
1
)
return
tensors
[
0
].
set_value
(
value
,
layout
=
'coo'
)
rowcount
=
tensor
.
storage
.
_rowcount
else
:
if
rowcount
is
not
None
:
raise
IndexError
(
rowcounts
.
append
(
rowcount
)
(
f
'Dimension out of range: Expected to be in range of '
f
'[
{
-
tensors
[
0
].
dim
()
}
,
{
tensors
[
0
].
dim
()
-
1
}
], but got '
f
'
{
dim
}
.'
))
else
:
assert
isinstance
(
dim
,
(
tuple
,
list
))
assert
len
(
dim
)
==
2
assert
sorted
(
dim
)
==
[
0
,
1
]
return
cat_diag
(
tensors
)
sparse_sizes
[
0
]
+=
tensor
.
sparse_size
(
0
)
sparse_sizes
[
1
]
=
max
(
sparse_sizes
[
1
],
tensor
.
sparse_size
(
1
))
nnz
+=
tensor
.
nnz
()
row
:
Optional
[
torch
.
Tensor
]
=
None
def
cat_first
(
tensors
:
List
[
SparseTensor
])
->
SparseTensor
:
if
len
(
rows
)
==
len
(
tensors
):
rows
:
List
[
torch
.
Tensor
]
=
[]
row
=
torch
.
cat
(
rows
,
dim
=
0
)
rowptrs
:
List
[
torch
.
Tensor
]
=
[]
cols
:
List
[
torch
.
Tensor
]
=
[]
values
:
List
[
torch
.
Tensor
]
=
[]
sparse_sizes
:
List
[
int
]
=
[
0
,
0
]
rowcounts
:
List
[
torch
.
Tensor
]
=
[]
rowptr
:
Optional
[
torch
.
Tensor
]
=
None
nnz
:
int
=
0
if
len
(
rowptrs
)
==
len
(
tensors
):
for
tensor
in
tensors
:
rowptr
=
torch
.
cat
(
rowptrs
,
dim
=
0
)
row
=
tensor
.
storage
.
_row
if
row
is
not
None
:
rows
.
append
(
row
+
sparse_sizes
[
0
])
col
=
torch
.
cat
(
cols
,
dim
=
0
)
rowptr
=
tensor
.
storage
.
_rowptr
if
rowptr
is
not
None
:
rowptrs
.
append
(
rowptr
[
1
:]
+
nnz
if
len
(
rowptrs
)
>
0
else
rowptr
)
value
:
Optional
[
torch
.
Tensor
]
=
None
cols
.
append
(
tensor
.
storage
.
_col
)
if
len
(
values
)
==
len
(
tensors
):
value
=
torch
.
cat
(
values
,
dim
=
0
)
rowcount
:
Optional
[
torch
.
Tensor
]
=
Non
e
value
=
tensor
.
storage
.
_valu
e
if
len
(
rowcounts
)
==
len
(
tensors
)
:
if
value
is
not
None
:
rowcount
=
torch
.
cat
(
rowcounts
,
dim
=
0
)
values
.
append
(
value
)
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
rowcount
=
tensor
.
storage
.
_rowcount
sparse_sizes
=
sparse_sizes
,
rowcount
=
rowcount
,
if
rowcount
is
not
None
:
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
rowcounts
.
append
(
rowcount
)
csc2csr
=
None
,
is_sorted
=
True
)
return
tensors
[
0
].
from_storage
(
storage
)
elif
dim
==
1
:
sparse_sizes
[
0
]
+=
tensor
.
sparse_size
(
0
)
rows
:
List
[
torch
.
Tensor
]
=
[]
sparse_sizes
[
1
]
=
max
(
sparse_sizes
[
1
],
tensor
.
sparse_size
(
1
))
cols
:
List
[
torch
.
Tensor
]
=
[]
nnz
+=
tensor
.
nnz
()
values
:
List
[
torch
.
Tensor
]
=
[]
sparse_sizes
:
List
[
int
]
=
[
0
,
0
]
colptrs
:
List
[
torch
.
Tensor
]
=
[]
colcounts
:
List
[
torch
.
Tensor
]
=
[]
nnz
:
int
=
0
row
:
Optional
[
torch
.
Tensor
]
=
None
for
tensor
in
tensors
:
if
len
(
rows
)
==
len
(
tensors
)
:
row
,
col
,
value
=
tensor
.
coo
(
)
row
=
torch
.
cat
(
rows
,
dim
=
0
)
rows
.
append
(
row
)
rowptr
:
Optional
[
torch
.
Tensor
]
=
None
if
len
(
rowptrs
)
==
len
(
tensors
):
rowptr
=
torch
.
cat
(
rowptrs
,
dim
=
0
)
cols
.
append
(
tensor
.
storage
.
_col
+
sparse_sizes
[
1
]
)
col
=
torch
.
cat
(
cols
,
dim
=
0
)
if
value
is
not
None
:
value
:
Optional
[
torch
.
Tensor
]
=
None
values
.
append
(
value
)
if
len
(
values
)
==
len
(
tensors
):
value
=
torch
.
cat
(
values
,
dim
=
0
)
colptr
=
tensor
.
storage
.
_colptr
rowcount
:
Optional
[
torch
.
Tensor
]
=
None
if
colptr
is
not
None
:
if
len
(
rowcounts
)
==
len
(
tensors
):
if
len
(
colptrs
)
>
0
:
rowcount
=
torch
.
cat
(
rowcounts
,
dim
=
0
)
colptr
=
colptr
[
1
:]
colptrs
.
append
(
colptr
+
nnz
)
colcount
=
tensor
.
storage
.
_colcount
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
if
colcount
is
not
None
:
sparse_sizes
=
(
sparse_sizes
[
0
],
sparse_sizes
[
1
]),
colcounts
.
append
(
colcount
)
rowcount
=
rowcount
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
return
tensors
[
0
].
from_storage
(
storage
)
sparse_sizes
[
0
]
=
max
(
sparse_sizes
[
0
],
tensor
.
sparse_size
(
0
))
sparse_sizes
[
1
]
+=
tensor
.
sparse_size
(
1
)
nnz
+=
tensor
.
nnz
()
row
=
torch
.
cat
(
rows
,
dim
=
0
)
def
cat_second
(
tensors
:
List
[
SparseTensor
])
->
SparseTensor
:
rows
:
List
[
torch
.
Tensor
]
=
[]
cols
:
List
[
torch
.
Tensor
]
=
[]
values
:
List
[
torch
.
Tensor
]
=
[]
sparse_sizes
:
List
[
int
]
=
[
0
,
0
]
colptrs
:
List
[
torch
.
Tensor
]
=
[]
colcounts
:
List
[
torch
.
Tensor
]
=
[]
nnz
:
int
=
0
for
tensor
in
tensors
:
row
,
col
,
value
=
tensor
.
coo
()
rows
.
append
(
row
)
cols
.
append
(
tensor
.
storage
.
_col
+
sparse_sizes
[
1
])
if
value
is
not
None
:
values
.
append
(
value
)
col
=
torch
.
cat
(
cols
,
dim
=
0
)
colptr
=
tensor
.
storage
.
_colptr
if
colptr
is
not
None
:
colptrs
.
append
(
colptr
[
1
:]
+
nnz
if
len
(
colptrs
)
>
0
else
colptr
)
value
:
Optional
[
torch
.
Tensor
]
=
None
colcount
=
tensor
.
storage
.
_colcount
if
len
(
values
)
==
len
(
tensors
)
:
if
colcount
is
not
None
:
value
=
torch
.
cat
(
values
,
dim
=
0
)
colcounts
.
append
(
colcount
)
colptr
:
Optional
[
torch
.
Tensor
]
=
None
sparse_sizes
[
0
]
=
max
(
sparse_sizes
[
0
],
tensor
.
sparse_size
(
0
))
if
len
(
colptrs
)
==
len
(
tensors
):
sparse_sizes
[
1
]
+=
tensor
.
sparse_size
(
1
)
colptr
=
torch
.
cat
(
colptrs
,
dim
=
0
)
nnz
+=
tensor
.
nnz
(
)
colcount
:
Optional
[
torch
.
Tensor
]
=
None
row
=
torch
.
cat
(
rows
,
dim
=
0
)
if
len
(
colcounts
)
==
len
(
tensors
):
col
=
torch
.
cat
(
cols
,
dim
=
0
)
colcount
=
torch
.
cat
(
colcounts
,
dim
=
0
)
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
value
:
Optional
[
torch
.
Tensor
]
=
None
sparse_sizes
=
sparse_sizes
,
rowcount
=
None
,
if
len
(
values
)
==
len
(
tensors
):
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
None
,
value
=
torch
.
cat
(
values
,
dim
=
0
)
csc2csr
=
None
,
is_sorted
=
False
)
return
tensors
[
0
].
from_storage
(
storage
)
elif
dim
>
1
and
dim
<
tensors
[
0
].
dim
():
colptr
:
Optional
[
torch
.
Tensor
]
=
None
values
:
List
[
torch
.
Tensor
]
=
[]
if
len
(
colptrs
)
==
len
(
tensors
):
for
tensor
in
tensors
:
colptr
=
torch
.
cat
(
colptrs
,
dim
=
0
)
value
=
tensor
.
storage
.
value
()
if
value
is
not
None
:
values
.
append
(
value
)
value
:
Optional
[
torch
.
Tensor
]
=
None
colcount
:
Optional
[
torch
.
Tensor
]
=
None
if
len
(
value
s
)
==
len
(
tensors
):
if
len
(
colcount
s
)
==
len
(
tensors
):
value
=
torch
.
cat
(
values
,
dim
=
dim
-
1
)
colcount
=
torch
.
cat
(
colcounts
,
dim
=
0
)
return
tensors
[
0
].
set_
value
(
value
,
layout
=
'coo'
)
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
else
:
sparse_sizes
=
(
sparse_sizes
[
0
],
sparse_sizes
[
1
]),
raise
IndexError
(
rowcount
=
None
,
colptr
=
colptr
,
colcount
=
colcount
,
(
f
'Dimension out of range: Expected to be in range of '
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
False
)
f
'[
{
-
tensors
[
0
].
dim
()
}
,
{
tensors
[
0
].
dim
()
-
1
}
], but got
{
dim
}
.'
)
)
return
tensors
[
0
].
from_storage
(
storage
)
def
cat_diag
(
tensors
:
List
[
SparseTensor
])
->
SparseTensor
:
def
cat_diag
(
tensors
:
List
[
SparseTensor
])
->
SparseTensor
:
...
@@ -163,9 +187,7 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
...
@@ -163,9 +187,7 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
rowptr
=
tensor
.
storage
.
_rowptr
rowptr
=
tensor
.
storage
.
_rowptr
if
rowptr
is
not
None
:
if
rowptr
is
not
None
:
if
len
(
rowptrs
)
>
0
:
rowptrs
.
append
(
rowptr
[
1
:]
+
nnz
if
len
(
rowptrs
)
>
0
else
rowptr
)
rowptr
=
rowptr
[
1
:]
rowptrs
.
append
(
rowptr
+
nnz
)
cols
.
append
(
tensor
.
storage
.
_col
+
sparse_sizes
[
1
])
cols
.
append
(
tensor
.
storage
.
_col
+
sparse_sizes
[
1
])
...
@@ -179,9 +201,7 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
...
@@ -179,9 +201,7 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
colptr
=
tensor
.
storage
.
_colptr
colptr
=
tensor
.
storage
.
_colptr
if
colptr
is
not
None
:
if
colptr
is
not
None
:
if
len
(
colptrs
)
>
0
:
colptrs
.
append
(
colptr
[
1
:]
+
nnz
if
len
(
colptrs
)
>
0
else
colptr
)
colptr
=
colptr
[
1
:]
colptrs
.
append
(
colptr
+
nnz
)
colcount
=
tensor
.
storage
.
_colcount
colcount
=
tensor
.
storage
.
_colcount
if
colcount
is
not
None
:
if
colcount
is
not
None
:
...
@@ -234,7 +254,8 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
...
@@ -234,7 +254,8 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
csc2csr
=
torch
.
cat
(
csc2csrs
,
dim
=
0
)
csc2csr
=
torch
.
cat
(
csc2csrs
,
dim
=
0
)
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
rowcount
,
sparse_sizes
=
(
sparse_sizes
[
0
],
sparse_sizes
[
1
]),
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
csc2csr
=
csc2csr
,
is_sorted
=
True
)
return
tensors
[
0
].
from_storage
(
storage
)
return
tensors
[
0
].
from_storage
(
storage
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment