Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
301ecec5
Commit
301ecec5
authored
Jan 13, 2020
by
rusty1s
Browse files
Cat implementation
parent
50ac1233
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
74 additions
and
17 deletions
+74
-17
torch_sparse/cat.py
torch_sparse/cat.py
+72
-16
torch_sparse/reduce.py
torch_sparse/reduce.py
+2
-1
No files found.
torch_sparse/cat.py
View file @
301ecec5
...
...
@@ -4,27 +4,33 @@ import torch
def
cat
(
tensors
,
dim
):
assert
len
(
tensors
)
>
0
has_value
=
tensors
[
0
].
has_value
()
has_rowcount
=
tensors
[
0
].
storage
.
_rowcount
is
not
None
has_rowptr
=
tensors
[
0
].
storage
.
_rowptr
is
not
None
has_colcount
=
tensors
[
0
].
storage
.
_colcount
is
not
None
has_colptr
=
tensors
[
0
].
storage
.
_colptr
is
not
None
has_csr2csc
=
tensors
[
0
].
storage
.
_csr2csc
is
not
None
has_csc2csr
=
tensors
[
0
].
storage
.
_csc2csr
is
not
None
has_rowcount
=
tensors
[
0
].
storage
.
has
_rowcount
()
has_rowptr
=
tensors
[
0
].
storage
.
has
_rowptr
()
has_colcount
=
tensors
[
0
].
storage
.
has
_colcount
()
has_colptr
=
tensors
[
0
].
storage
.
has
_colptr
()
has_csr2csc
=
tensors
[
0
].
storage
.
has
_csr2csc
()
has_csc2csr
=
tensors
[
0
].
storage
.
has
_csc2csr
()
rows
,
cols
,
values
,
sparse_size
=
[],
[],
[],
[
0
,
0
]
rowcounts
,
rowptrs
,
colcounts
,
colptrs
=
[],
[],
[],
[]
csr2cscs
,
csc2csrs
,
nnzs
=
[],
[],
0
if
isinstance
(
dim
,
int
):
dim
=
tensors
[
0
].
dim
()
+
dim
if
dim
<
0
else
dim
else
:
dim
=
tuple
([
tensors
[
0
].
dim
()
+
d
if
d
<
0
else
d
for
d
in
dim
])
if
dim
==
0
:
for
tensor
in
tensors
:
(
row
,
col
),
value
=
tensor
.
coo
()
rows
+=
[
row
+
sparse_size
[
0
]]
cols
+=
[
col
]
values
+=
[
value
]
if
has_value
else
[]
values
+=
[
value
]
sparse_size
[
0
]
+=
tensor
.
sparse_size
(
0
)
sparse_size
[
1
]
=
max
(
sparse_size
[
1
],
tensor
.
sparse_size
(
1
))
rowcounts
+=
[
tensor
.
storage
.
rowcount
]
if
has_rowcount
else
[]
if
has_rowcount
:
rowcounts
+=
[
tensor
.
storage
.
rowcount
]
if
has_rowptr
:
rowptr
=
tensor
.
storage
.
rowptr
...
...
@@ -40,10 +46,33 @@ def cat(tensors, dim):
rowcount
=
torch
.
cat
(
rowcounts
)
if
has_rowcount
else
None
,
rowptr
=
torch
.
cat
(
rowptrs
)
if
has_rowptr
else
None
,
is_sorted
=
True
)
if
dim
==
1
:
raise
NotImplementedError
elif
dim
==
1
:
for
tensor
in
tensors
:
(
row
,
col
),
value
=
tensor
.
coo
()
rows
+=
[
row
]
cols
+=
[
col
+
sparse_size
[
1
]]
values
+=
[
value
]
sparse_size
[
0
]
=
max
(
sparse_size
[
0
],
tensor
.
sparse_size
(
0
))
sparse_size
[
1
]
+=
tensor
.
sparse_size
(
1
)
if
has_colcount
:
colcounts
+=
[
tensor
.
storage
.
colcount
]
if
has
colptr
:
colptr
=
tensor
.
storage
.
colptr
colptr
=
colptr
if
len
(
colptrs
)
==
0
else
colptr
[
1
:]
colptrs
+=
[
colptr
+
nnzs
]
nnzs
+=
tensor
.
nnz
()
storage
=
tensors
[
0
].
storage
.
__class__
(
torch
.
stack
([
torch
.
cat
(
rows
),
torch
.
cat
(
cols
)],
dim
=
0
),
value
=
torch
.
cat
(
values
,
dim
=
0
)
if
has_value
else
None
,
sparse_size
=
sparse_size
,
colcount
=
torch
.
cat
(
colcounts
)
if
has_colcount
else
None
,
colptr
=
torch
.
cat
(
colptrs
)
if
has_colptr
else
None
,
is_sorted
=
False
)
if
dim
==
(
0
,
1
)
or
(
1
,
0
):
el
if
dim
==
(
0
,
1
)
or
(
1
,
0
):
for
tensor
in
tensors
:
(
row
,
col
),
value
=
tensor
.
coo
()
rows
+=
[
row
+
sparse_size
[
0
]]
...
...
@@ -52,21 +81,27 @@ def cat(tensors, dim):
sparse_size
[
0
]
+=
tensor
.
sparse_size
(
0
)
sparse_size
[
1
]
+=
tensor
.
sparse_size
(
1
)
rowcounts
+=
[
tensor
.
storage
.
rowcount
]
if
has_rowcount
else
[]
col
counts
+=
[
tensor
.
storage
.
colcount
]
if
has_colcount
else
[
]
if
has_rowcount
:
row
counts
+=
[
tensor
.
storage
.
rowcount
]
if
has_rowptr
:
rowptr
=
tensor
.
storage
.
rowptr
rowptr
=
rowptr
if
len
(
rowptrs
)
==
0
else
rowptr
[
1
:]
rowptrs
+=
[
rowptr
+
nnzs
]
if
has_colcount
:
colcounts
+=
[
tensor
.
storage
.
colcount
]
if
has_colptr
:
colptr
=
tensor
.
storage
.
colptr
colptr
=
colptr
if
len
(
colptrs
)
==
0
else
colptr
[
1
:]
colptrs
+=
[
colptr
+
nnzs
]
csr2cscs
+=
[
tensor
.
storage
.
csr2csc
+
nnzs
]
if
has_csr2csc
else
[]
csc2csrs
+=
[
tensor
.
storage
.
csc2csr
+
nnzs
]
if
has_csc2csr
else
[]
if
has_csr2csc
:
csr2cscs
+=
[
tensor
.
storage
.
csr2csc
+
nnzs
]
if
has_csc2csr
:
csc2csrs
+=
[
tensor
.
storage
.
csc2csr
+
nnzs
]
nnzs
+=
tensor
.
nnz
()
...
...
@@ -82,7 +117,28 @@ def cat(tensors, dim):
csc2csr
=
torch
.
cat
(
csc2csrs
)
if
has_csc2csr
else
None
,
is_sorted
=
True
)
elif
isinstance
(
dim
,
int
)
and
dim
>
1
and
dim
<
tensors
[
0
].
dim
():
for
tensor
in
tensors
:
values
+=
[
tensor
.
storage
.
value
]
sparse_size
[
0
]
=
max
(
sparse_size
[
0
],
tensor
.
sparse_size
(
0
))
sparse_size
[
1
]
=
max
(
sparse_size
[
1
],
tensor
.
sparse_size
(
1
))
old_storage
=
tensors
[
0
].
storage
storage
=
old_storage
.
storage
.
__class__
(
tensors
[
0
].
storage
.
index
,
value
=
torch
.
cat
(
values
,
dim
=
dim
-
1
),
sparse_size
=
sparse_size
,
rowcount
=
old_storage
.
_rowcount
,
rowptr
=
old_storage
.
_rowcount
,
colcount
=
old_storage
.
_rowcount
,
colptr
=
old_storage
.
_rowcount
,
csr2csc
=
old_storage
.
_csr2csc
,
csc2csr
=
old_storage
.
_csc2csr
,
is_sorted
=
True
)
else
:
raise
NotImplementedError
raise
IndexError
(
(
f
'Dimension out of range: Expected to be in range of '
f
'[
{
-
tensors
[
0
].
dim
()
}
,
{
tensors
[
0
].
dim
()
-
1
}
, but got
{
dim
}
]'
))
return
tensors
[
0
].
__class__
.
from_storage
(
storage
)
torch_sparse/reduce.py
View file @
301ecec5
...
...
@@ -14,7 +14,8 @@ def __reduce__(src, dim=None, reduce='add', deterministic=False):
value
=
src
.
nnz
()
if
reduce
==
'add'
else
1
return
torch
.
tensor
(
value
,
device
=
src
.
device
)
dims
=
[
dim
]
if
isinstance
(
dim
,
int
)
else
sorted
(
list
(
dim
))
dims
=
[
dim
]
if
isinstance
(
dim
,
int
)
else
dim
dims
=
sorted
([
src
.
dim
()
+
dim
if
dim
<
0
else
dim
for
dim
in
dims
])
assert
dims
[
-
1
]
<
src
.
dim
()
rowptr
,
col
,
value
=
src
.
csr
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment