Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
bc522dd9
Commit
bc522dd9
authored
Jan 13, 2020
by
rusty1s
Browse files
cat tests
parent
69cab8ac
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
47 additions
and
17 deletions
+47
-17
test/test_cat.py
test/test_cat.py
+29
-6
torch_sparse/cat.py
torch_sparse/cat.py
+18
-11
No files found.
test/test_cat.py
View file @
bc522dd9
from
itertools
import
product
import
pytest
import
pytest
import
torch
import
torch
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.cat
import
cat
from
torch_sparse.cat
import
cat
from
.utils
import
dtypes
,
devices
,
tensor
from
.utils
import
devices
,
tensor
@
pytest
.
mark
.
parametrize
(
'
dtype,
device'
,
product
(
dtypes
,
devices
)
)
@
pytest
.
mark
.
parametrize
(
'device'
,
devices
)
def
test_cat
(
dtype
,
device
):
def
test_cat
(
device
):
index
=
tensor
([[
0
,
0
,
1
],
[
0
,
1
,
2
]],
torch
.
long
,
device
)
index
=
tensor
([[
0
,
0
,
1
],
[
0
,
1
,
2
]],
torch
.
long
,
device
)
mat1
=
SparseTensor
(
index
)
mat1
=
SparseTensor
(
index
)
mat1
.
fill_cache_
()
index
=
tensor
([[
0
,
0
,
1
,
2
],
[
0
,
1
,
1
,
0
]],
torch
.
long
,
device
)
index
=
tensor
([[
0
,
0
,
1
,
2
],
[
0
,
1
,
1
,
0
]],
torch
.
long
,
device
)
mat2
=
SparseTensor
(
index
)
mat2
=
SparseTensor
(
index
)
mat2
.
fill_cache_
()
out
=
cat
([
mat1
,
mat2
],
dim
=
0
)
assert
out
.
to_dense
().
tolist
()
==
[[
1
,
1
,
0
],
[
0
,
0
,
1
],
[
1
,
1
,
0
],
[
0
,
1
,
0
],
[
1
,
0
,
0
]]
assert
len
(
out
.
storage
.
cached_keys
())
==
2
assert
out
.
storage
.
has_rowcount
()
assert
out
.
storage
.
has_rowptr
()
out
=
cat
([
mat1
,
mat2
],
dim
=
1
)
assert
out
.
to_dense
().
tolist
()
==
[[
1
,
1
,
0
,
1
,
1
],
[
0
,
0
,
1
,
0
,
1
],
[
0
,
0
,
0
,
1
,
0
]]
assert
len
(
out
.
storage
.
cached_keys
())
==
2
assert
out
.
storage
.
has_colcount
()
assert
out
.
storage
.
has_colptr
()
out
=
cat
([
mat1
,
mat2
],
dim
=
(
0
,
1
))
assert
out
.
to_dense
().
tolist
()
==
[[
1
,
1
,
0
,
0
,
0
],
[
0
,
0
,
1
,
0
,
0
],
[
0
,
0
,
0
,
1
,
1
],
[
0
,
0
,
0
,
0
,
1
],
[
0
,
0
,
0
,
1
,
0
]]
assert
len
(
out
.
storage
.
cached_keys
())
==
6
cat
([
mat1
,
mat2
],
dim
=
(
0
,
1
))
mat1
.
set_value_
(
torch
.
randn
((
mat1
.
nnz
(),
4
),
device
=
device
))
out
=
cat
([
mat1
,
mat1
],
dim
=-
1
)
assert
out
.
storage
.
value
.
size
()
==
(
mat1
.
nnz
(),
8
)
assert
len
(
out
.
storage
.
cached_keys
())
==
6
torch_sparse/cat.py
View file @
bc522dd9
...
@@ -70,9 +70,11 @@ def cat(tensors, dim):
...
@@ -70,9 +70,11 @@ def cat(tensors, dim):
value
=
torch
.
cat
(
values
,
dim
=
0
)
if
has_value
else
None
,
value
=
torch
.
cat
(
values
,
dim
=
0
)
if
has_value
else
None
,
sparse_size
=
sparse_size
,
sparse_size
=
sparse_size
,
colcount
=
torch
.
cat
(
colcounts
)
if
has_colcount
else
None
,
colcount
=
torch
.
cat
(
colcounts
)
if
has_colcount
else
None
,
colptr
=
torch
.
cat
(
colptrs
)
if
has_colptr
else
None
,
is_sorted
=
False
)
colptr
=
torch
.
cat
(
colptrs
)
if
has_colptr
else
None
,
is_sorted
=
False
,
)
elif
dim
==
(
0
,
1
)
or
(
1
,
0
):
elif
dim
==
(
0
,
1
)
or
dim
==
(
1
,
0
):
for
tensor
in
tensors
:
for
tensor
in
tensors
:
(
row
,
col
),
value
=
tensor
.
coo
()
(
row
,
col
),
value
=
tensor
.
coo
()
rows
+=
[
row
+
sparse_size
[
0
]]
rows
+=
[
row
+
sparse_size
[
0
]]
...
@@ -115,21 +117,26 @@ def cat(tensors, dim):
...
@@ -115,21 +117,26 @@ def cat(tensors, dim):
colptr
=
torch
.
cat
(
colptrs
)
if
has_colptr
else
None
,
colptr
=
torch
.
cat
(
colptrs
)
if
has_colptr
else
None
,
csr2csc
=
torch
.
cat
(
csr2cscs
)
if
has_csr2csc
else
None
,
csr2csc
=
torch
.
cat
(
csr2cscs
)
if
has_csr2csc
else
None
,
csc2csr
=
torch
.
cat
(
csc2csrs
)
if
has_csc2csr
else
None
,
csc2csr
=
torch
.
cat
(
csc2csrs
)
if
has_csc2csr
else
None
,
is_sorted
=
True
)
is_sorted
=
True
,
)
elif
isinstance
(
dim
,
int
)
and
dim
>
1
and
dim
<
tensors
[
0
].
dim
():
elif
isinstance
(
dim
,
int
)
and
dim
>
1
and
dim
<
tensors
[
0
].
dim
():
for
tensor
in
tensors
:
for
tensor
in
tensors
:
values
+=
[
tensor
.
storage
.
value
]
values
+=
[
tensor
.
storage
.
value
]
sparse_size
[
0
]
=
max
(
sparse_size
[
0
],
tensor
.
sparse_size
(
0
))
sparse_size
[
1
]
=
max
(
sparse_size
[
1
],
tensor
.
sparse_size
(
1
))
old_storage
=
tensors
[
0
].
storage
old_storage
=
tensors
[
0
].
storage
storage
=
old_storage
.
storage
.
__class__
(
storage
=
old_storage
.
__class__
(
tensors
[
0
].
storage
.
index
,
value
=
torch
.
cat
(
values
,
dim
=
dim
-
1
),
tensors
[
0
].
storage
.
index
,
sparse_size
=
sparse_size
,
rowcount
=
old_storage
.
_rowcount
,
value
=
torch
.
cat
(
values
,
dim
=
dim
-
1
),
rowptr
=
old_storage
.
_rowcount
,
colcount
=
old_storage
.
_rowcount
,
sparse_size
=
old_storage
.
sparse_size
(),
colptr
=
old_storage
.
_rowcount
,
csr2csc
=
old_storage
.
_csr2csc
,
rowcount
=
old_storage
.
_rowcount
,
csc2csr
=
old_storage
.
_csc2csr
,
is_sorted
=
True
)
rowptr
=
old_storage
.
_rowptr
,
colcount
=
old_storage
.
_colcount
,
colptr
=
old_storage
.
_colptr
,
csr2csc
=
old_storage
.
_csr2csc
,
csc2csr
=
old_storage
.
_csc2csr
,
is_sorted
=
True
,
)
else
:
else
:
raise
IndexError
(
raise
IndexError
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment