Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
d30ed1d5
Commit
d30ed1d5
authored
Dec 21, 2019
by
rusty1s
Browse files
cat
parent
c2bd0b10
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
108 additions
and
0 deletions
+108
-0
test/test_cat.py
test/test_cat.py
+19
-0
torch_sparse/cat.py
torch_sparse/cat.py
+89
-0
No files found.
test/test_cat.py
0 → 100644
View file @
d30ed1d5
from
itertools
import
product
import
pytest
import
torch
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.cat
import
cat
from
.utils
import
dtypes
,
devices
,
tensor
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
def
test_cat
(
dtype
,
device
):
index
=
tensor
([[
0
,
0
,
1
],
[
0
,
1
,
2
]],
torch
.
long
,
device
)
mat1
=
SparseTensor
(
index
)
index
=
tensor
([[
0
,
0
,
1
,
2
],
[
0
,
1
,
1
,
0
]],
torch
.
long
,
device
)
mat2
=
SparseTensor
(
index
)
cat
([
mat1
,
mat2
],
dim
=
(
0
,
1
))
torch_sparse/cat.py
0 → 100644
View file @
d30ed1d5
import
torch
def
cat
(
tensors
,
dim
):
assert
len
(
tensors
)
>
0
has_value
=
tensors
[
0
].
has_value
()
has_rowcount
=
tensors
[
0
].
storage
.
_rowcount
is
not
None
has_rowptr
=
tensors
[
0
].
storage
.
_rowptr
is
not
None
has_colcount
=
tensors
[
0
].
storage
.
_colcount
is
not
None
has_colptr
=
tensors
[
0
].
storage
.
_colptr
is
not
None
has_csr2csc
=
tensors
[
0
].
storage
.
_csr2csc
is
not
None
has_csc2csr
=
tensors
[
0
].
storage
.
_csc2csr
is
not
None
rows
,
cols
,
values
,
sparse_size
=
[],
[],
[],
[
0
,
0
]
rowcounts
,
rowptrs
,
colcounts
,
colptrs
=
[],
[],
[],
[]
csr2cscs
,
csc2csrs
,
nnzs
=
[],
[],
0
if
dim
==
0
:
for
tensor
in
tensors
:
(
row
,
col
),
value
=
tensor
.
coo
()
rows
+=
[
row
+
sparse_size
[
0
]]
cols
+=
[
col
]
values
+=
[
value
]
if
has_value
else
[]
sparse_size
[
0
]
+=
tensor
.
sparse_size
(
0
)
sparse_size
[
1
]
=
max
(
sparse_size
[
1
],
tensor
.
sparse_size
(
1
))
rowcounts
+=
[
tensor
.
storage
.
rowcount
]
if
has_rowcount
else
[]
if
has_rowptr
:
rowptr
=
tensor
.
storage
.
rowptr
rowptr
=
rowptr
if
len
(
rowptrs
)
==
0
else
rowptr
[
1
:]
rowptrs
+=
[
rowptr
+
nnzs
]
nnzs
+=
tensor
.
nnz
()
storage
=
tensors
[
0
].
storage
.
__class__
(
torch
.
stack
([
torch
.
cat
(
rows
),
torch
.
cat
(
cols
)],
dim
=
0
),
value
=
torch
.
cat
(
values
,
dim
=
0
)
if
has_value
else
None
,
sparse_size
=
sparse_size
,
rowcount
=
torch
.
cat
(
rowcounts
)
if
has_rowcount
else
None
,
rowptr
=
torch
.
cat
(
rowptrs
)
if
has_rowptr
else
None
,
is_sorted
=
True
)
if
dim
==
1
:
raise
NotImplementedError
if
dim
==
(
0
,
1
)
or
(
1
,
0
):
for
tensor
in
tensors
:
(
row
,
col
),
value
=
tensor
.
coo
()
rows
+=
[
row
+
sparse_size
[
0
]]
cols
+=
[
col
+
sparse_size
[
1
]]
values
+=
[
value
]
if
has_value
else
[]
sparse_size
[
0
]
+=
tensor
.
sparse_size
(
0
)
sparse_size
[
1
]
+=
tensor
.
sparse_size
(
1
)
rowcounts
+=
[
tensor
.
storage
.
rowcount
]
if
has_rowcount
else
[]
colcounts
+=
[
tensor
.
storage
.
colcount
]
if
has_colcount
else
[]
if
has_rowptr
:
rowptr
=
tensor
.
storage
.
rowptr
rowptr
=
rowptr
if
len
(
rowptrs
)
==
0
else
rowptr
[
1
:]
rowptrs
+=
[
rowptr
+
nnzs
]
if
has_colptr
:
colptr
=
tensor
.
storage
.
colptr
colptr
=
colptr
if
len
(
colptrs
)
==
0
else
colptr
[
1
:]
colptrs
+=
[
colptr
+
nnzs
]
csr2cscs
+=
[
tensor
.
storage
.
csr2csc
+
nnzs
]
if
has_csr2csc
else
[]
csc2csrs
+=
[
tensor
.
storage
.
csc2csr
+
nnzs
]
if
has_csc2csr
else
[]
nnzs
+=
tensor
.
nnz
()
storage
=
tensors
[
0
].
storage
.
__class__
(
torch
.
stack
([
torch
.
cat
(
rows
),
torch
.
cat
(
cols
)],
dim
=
0
),
value
=
torch
.
cat
(
values
,
dim
=
0
)
if
has_value
else
None
,
sparse_size
=
sparse_size
,
rowcount
=
torch
.
cat
(
rowcounts
)
if
has_rowcount
else
None
,
rowptr
=
torch
.
cat
(
rowptrs
)
if
has_rowptr
else
None
,
colcount
=
torch
.
cat
(
colcounts
)
if
has_colcount
else
None
,
colptr
=
torch
.
cat
(
colptrs
)
if
has_colptr
else
None
,
csr2csc
=
torch
.
cat
(
csr2cscs
)
if
has_csr2csc
else
None
,
csc2csr
=
torch
.
cat
(
csc2csrs
)
if
has_csc2csr
else
None
,
is_sorted
=
True
)
else
:
raise
NotImplementedError
return
tensors
[
0
].
__class__
.
from_storage
(
storage
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment