Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
872938af
Commit
872938af
authored
Jul 01, 2020
by
rusty1s
Browse files
overload for cat
parent
468aea5b
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
140 additions
and
120 deletions
+140
-120
test/test_cat.py
test/test_cat.py
+2
-2
torch_sparse/__init__.py
torch_sparse/__init__.py
+1
-2
torch_sparse/cat.py
torch_sparse/cat.py
+137
-116
No files found.
test/test_cat.py
View file @
872938af
import
pytest
import
pytest
import
torch
import
torch
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.cat
import
cat
,
cat_diag
from
torch_sparse.cat
import
cat
from
.utils
import
devices
,
tensor
from
.utils
import
devices
,
tensor
...
@@ -31,7 +31,7 @@ def test_cat(device):
...
@@ -31,7 +31,7 @@ def test_cat(device):
assert
not
out
.
storage
.
has_rowptr
()
assert
not
out
.
storage
.
has_rowptr
()
assert
out
.
storage
.
num_cached_keys
()
==
2
assert
out
.
storage
.
num_cached_keys
()
==
2
out
=
cat
_diag
([
mat1
,
mat2
])
out
=
cat
([
mat1
,
mat2
]
,
dim
=
(
0
,
1
)
)
assert
out
.
to_dense
().
tolist
()
==
[[
1
,
1
,
0
,
0
,
0
],
[
0
,
0
,
1
,
0
,
0
],
assert
out
.
to_dense
().
tolist
()
==
[[
1
,
1
,
0
,
0
,
0
],
[
0
,
0
,
1
,
0
,
0
],
[
0
,
0
,
0
,
1
,
1
],
[
0
,
0
,
0
,
0
,
1
],
[
0
,
0
,
0
,
1
,
1
],
[
0
,
0
,
0
,
0
,
1
],
[
0
,
0
,
0
,
1
,
0
]]
[
0
,
0
,
0
,
1
,
0
]]
...
...
torch_sparse/__init__.py
View file @
872938af
...
@@ -44,7 +44,7 @@ from .add import add, add_, add_nnz, add_nnz_ # noqa
...
@@ -44,7 +44,7 @@ from .add import add, add_, add_nnz, add_nnz_ # noqa
from
.mul
import
mul
,
mul_
,
mul_nnz
,
mul_nnz_
# noqa
from
.mul
import
mul
,
mul_
,
mul_nnz
,
mul_nnz_
# noqa
from
.reduce
import
sum
,
mean
,
min
,
max
# noqa
from
.reduce
import
sum
,
mean
,
min
,
max
# noqa
from
.matmul
import
matmul
# noqa
from
.matmul
import
matmul
# noqa
from
.cat
import
cat
,
cat_diag
# noqa
from
.cat
import
cat
# noqa
from
.rw
import
random_walk
# noqa
from
.rw
import
random_walk
# noqa
from
.metis
import
partition
# noqa
from
.metis
import
partition
# noqa
from
.bandwidth
import
reverse_cuthill_mckee
# noqa
from
.bandwidth
import
reverse_cuthill_mckee
# noqa
...
@@ -89,7 +89,6 @@ __all__ = [
...
@@ -89,7 +89,6 @@ __all__ = [
'max'
,
'max'
,
'matmul'
,
'matmul'
,
'cat'
,
'cat'
,
'cat_diag'
,
'random_walk'
,
'random_walk'
,
'partition'
,
'partition'
,
'reverse_cuthill_mckee'
,
'reverse_cuthill_mckee'
,
...
...
torch_sparse/cat.py
View file @
872938af
from
typing
import
Optional
,
List
from
typing
import
Optional
,
List
,
Tuple
import
torch
import
torch
from
torch_sparse.storage
import
SparseStorage
from
torch_sparse.storage
import
SparseStorage
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
def
cat
(
tensors
:
List
[
SparseTensor
],
dim
:
int
)
->
SparseTensor
:
@
torch
.
jit
.
_overload
# noqa: F811
def
cat
(
tensors
,
dim
):
# noqa: F811
# type: (List[SparseTensor], int) -> SparseTensor
pass
@
torch
.
jit
.
_overload
# noqa: F811
def
cat
(
tensors
,
dim
):
# noqa: F811
# type: (List[SparseTensor], Tuple[int, int]) -> SparseTensor
pass
@
torch
.
jit
.
_overload
# noqa: F811
def
cat
(
tensors
,
dim
):
# noqa: F811
# type: (List[SparseTensor], List[int]) -> SparseTensor
pass
def
cat
(
tensors
,
dim
):
# noqa: F811
assert
len
(
tensors
)
>
0
assert
len
(
tensors
)
>
0
if
dim
<
0
:
dim
=
tensors
[
0
].
dim
()
+
dim
if
isinstance
(
dim
,
int
):
dim
=
tensors
[
0
].
dim
()
+
dim
if
dim
<
0
else
dim
if
dim
==
0
:
if
dim
==
0
:
return
cat_first
(
tensors
)
elif
dim
==
1
:
return
cat_second
(
tensors
)
pass
elif
dim
>
1
and
dim
<
tensors
[
0
].
dim
():
values
=
[]
for
tensor
in
tensors
:
value
=
tensor
.
storage
.
value
()
assert
value
is
not
None
values
.
append
(
value
)
value
=
torch
.
cat
(
values
,
dim
=
dim
-
1
)
return
tensors
[
0
].
set_value
(
value
,
layout
=
'coo'
)
else
:
raise
IndexError
(
(
f
'Dimension out of range: Expected to be in range of '
f
'[
{
-
tensors
[
0
].
dim
()
}
,
{
tensors
[
0
].
dim
()
-
1
}
], but got '
f
'
{
dim
}
.'
))
else
:
assert
isinstance
(
dim
,
(
tuple
,
list
))
assert
len
(
dim
)
==
2
assert
sorted
(
dim
)
==
[
0
,
1
]
return
cat_diag
(
tensors
)
def
cat_first
(
tensors
:
List
[
SparseTensor
])
->
SparseTensor
:
rows
:
List
[
torch
.
Tensor
]
=
[]
rows
:
List
[
torch
.
Tensor
]
=
[]
rowptrs
:
List
[
torch
.
Tensor
]
=
[]
rowptrs
:
List
[
torch
.
Tensor
]
=
[]
cols
:
List
[
torch
.
Tensor
]
=
[]
cols
:
List
[
torch
.
Tensor
]
=
[]
...
@@ -26,9 +73,7 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
...
@@ -26,9 +73,7 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
rowptr
=
tensor
.
storage
.
_rowptr
rowptr
=
tensor
.
storage
.
_rowptr
if
rowptr
is
not
None
:
if
rowptr
is
not
None
:
if
len
(
rowptrs
)
>
0
:
rowptrs
.
append
(
rowptr
[
1
:]
+
nnz
if
len
(
rowptrs
)
>
0
else
rowptr
)
rowptr
=
rowptr
[
1
:]
rowptrs
.
append
(
rowptr
+
nnz
)
cols
.
append
(
tensor
.
storage
.
_col
)
cols
.
append
(
tensor
.
storage
.
_col
)
...
@@ -63,12 +108,13 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
...
@@ -63,12 +108,13 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
rowcount
=
torch
.
cat
(
rowcounts
,
dim
=
0
)
rowcount
=
torch
.
cat
(
rowcounts
,
dim
=
0
)
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
rowcount
,
sparse_sizes
=
(
sparse_sizes
[
0
],
sparse_sizes
[
1
])
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
rowcount
=
rowcount
,
colptr
=
None
,
colcount
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
return
tensors
[
0
].
from_storage
(
storage
)
return
tensors
[
0
].
from_storage
(
storage
)
elif
dim
==
1
:
def
cat_second
(
tensors
:
List
[
SparseTensor
])
->
SparseTensor
:
rows
:
List
[
torch
.
Tensor
]
=
[]
rows
:
List
[
torch
.
Tensor
]
=
[]
cols
:
List
[
torch
.
Tensor
]
=
[]
cols
:
List
[
torch
.
Tensor
]
=
[]
values
:
List
[
torch
.
Tensor
]
=
[]
values
:
List
[
torch
.
Tensor
]
=
[]
...
@@ -79,9 +125,7 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
...
@@ -79,9 +125,7 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
nnz
:
int
=
0
nnz
:
int
=
0
for
tensor
in
tensors
:
for
tensor
in
tensors
:
row
,
col
,
value
=
tensor
.
coo
()
row
,
col
,
value
=
tensor
.
coo
()
rows
.
append
(
row
)
rows
.
append
(
row
)
cols
.
append
(
tensor
.
storage
.
_col
+
sparse_sizes
[
1
])
cols
.
append
(
tensor
.
storage
.
_col
+
sparse_sizes
[
1
])
if
value
is
not
None
:
if
value
is
not
None
:
...
@@ -89,9 +133,7 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
...
@@ -89,9 +133,7 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
colptr
=
tensor
.
storage
.
_colptr
colptr
=
tensor
.
storage
.
_colptr
if
colptr
is
not
None
:
if
colptr
is
not
None
:
if
len
(
colptrs
)
>
0
:
colptrs
.
append
(
colptr
[
1
:]
+
nnz
if
len
(
colptrs
)
>
0
else
colptr
)
colptr
=
colptr
[
1
:]
colptrs
.
append
(
colptr
+
nnz
)
colcount
=
tensor
.
storage
.
_colcount
colcount
=
tensor
.
storage
.
_colcount
if
colcount
is
not
None
:
if
colcount
is
not
None
:
...
@@ -102,7 +144,6 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
...
@@ -102,7 +144,6 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
nnz
+=
tensor
.
nnz
()
nnz
+=
tensor
.
nnz
()
row
=
torch
.
cat
(
rows
,
dim
=
0
)
row
=
torch
.
cat
(
rows
,
dim
=
0
)
col
=
torch
.
cat
(
cols
,
dim
=
0
)
col
=
torch
.
cat
(
cols
,
dim
=
0
)
value
:
Optional
[
torch
.
Tensor
]
=
None
value
:
Optional
[
torch
.
Tensor
]
=
None
...
@@ -118,28 +159,11 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
...
@@ -118,28 +159,11 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
colcount
=
torch
.
cat
(
colcounts
,
dim
=
0
)
colcount
=
torch
.
cat
(
colcounts
,
dim
=
0
)
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
None
,
sparse_sizes
=
(
sparse_sizes
[
0
],
sparse_sizes
[
1
])
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
None
,
rowcount
=
None
,
colptr
=
colptr
,
colcount
=
colcount
,
csc2csr
=
None
,
is_sorted
=
False
)
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
False
)
return
tensors
[
0
].
from_storage
(
storage
)
return
tensors
[
0
].
from_storage
(
storage
)
elif
dim
>
1
and
dim
<
tensors
[
0
].
dim
():
values
:
List
[
torch
.
Tensor
]
=
[]
for
tensor
in
tensors
:
value
=
tensor
.
storage
.
value
()
if
value
is
not
None
:
values
.
append
(
value
)
value
:
Optional
[
torch
.
Tensor
]
=
None
if
len
(
values
)
==
len
(
tensors
):
value
=
torch
.
cat
(
values
,
dim
=
dim
-
1
)
return
tensors
[
0
].
set_value
(
value
,
layout
=
'coo'
)
else
:
raise
IndexError
(
(
f
'Dimension out of range: Expected to be in range of '
f
'[
{
-
tensors
[
0
].
dim
()
}
,
{
tensors
[
0
].
dim
()
-
1
}
], but got
{
dim
}
.'
))
def
cat_diag
(
tensors
:
List
[
SparseTensor
])
->
SparseTensor
:
def
cat_diag
(
tensors
:
List
[
SparseTensor
])
->
SparseTensor
:
assert
len
(
tensors
)
>
0
assert
len
(
tensors
)
>
0
...
@@ -163,9 +187,7 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
...
@@ -163,9 +187,7 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
rowptr
=
tensor
.
storage
.
_rowptr
rowptr
=
tensor
.
storage
.
_rowptr
if
rowptr
is
not
None
:
if
rowptr
is
not
None
:
if
len
(
rowptrs
)
>
0
:
rowptrs
.
append
(
rowptr
[
1
:]
+
nnz
if
len
(
rowptrs
)
>
0
else
rowptr
)
rowptr
=
rowptr
[
1
:]
rowptrs
.
append
(
rowptr
+
nnz
)
cols
.
append
(
tensor
.
storage
.
_col
+
sparse_sizes
[
1
])
cols
.
append
(
tensor
.
storage
.
_col
+
sparse_sizes
[
1
])
...
@@ -179,9 +201,7 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
...
@@ -179,9 +201,7 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
colptr
=
tensor
.
storage
.
_colptr
colptr
=
tensor
.
storage
.
_colptr
if
colptr
is
not
None
:
if
colptr
is
not
None
:
if
len
(
colptrs
)
>
0
:
colptrs
.
append
(
colptr
[
1
:]
+
nnz
if
len
(
colptrs
)
>
0
else
colptr
)
colptr
=
colptr
[
1
:]
colptrs
.
append
(
colptr
+
nnz
)
colcount
=
tensor
.
storage
.
_colcount
colcount
=
tensor
.
storage
.
_colcount
if
colcount
is
not
None
:
if
colcount
is
not
None
:
...
@@ -234,7 +254,8 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
...
@@ -234,7 +254,8 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
csc2csr
=
torch
.
cat
(
csc2csrs
,
dim
=
0
)
csc2csr
=
torch
.
cat
(
csc2csrs
,
dim
=
0
)
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
rowcount
,
sparse_sizes
=
(
sparse_sizes
[
0
],
sparse_sizes
[
1
]),
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
csc2csr
=
csc2csr
,
is_sorted
=
True
)
return
tensors
[
0
].
from_storage
(
storage
)
return
tensors
[
0
].
from_storage
(
storage
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment