Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
436b2e50
"vscode:/vscode.git/clone" did not exist on "8b00a415ab5170a5a75b105402ca262d1fb7ac12"
Commit
436b2e50
authored
Nov 07, 2020
by
rusty1s
Browse files
more memory efficient
parent
906c97e4
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
41 additions
and
7 deletions
+41
-7
test/test_tensor.py
test/test_tensor.py
+18
-0
torch_sparse/storage.py
torch_sparse/storage.py
+0
-1
torch_sparse/tensor.py
torch_sparse/tensor.py
+23
-6
No files found.
test/test_tensor.py
View file @
436b2e50
...
...
@@ -19,3 +19,21 @@ def test_getitem(dtype, device):
assert
mat
[...,
:
10
].
sizes
()
==
[
50
,
10
]
assert
mat
[
idx1
,
idx2
].
sizes
()
==
[
10
,
10
]
assert
mat
[
idx1
.
tolist
()].
sizes
()
==
[
10
,
40
]
@
pytest
.
mark
.
parametrize
(
'device'
,
devices
)
def
test_to_symmetric
(
device
):
row
=
torch
.
tensor
([
0
,
0
,
0
,
1
,
1
],
device
=
device
)
col
=
torch
.
tensor
([
0
,
1
,
2
,
0
,
2
],
device
=
device
)
value
=
torch
.
arange
(
1
,
6
,
device
=
device
)
mat
=
SparseTensor
(
row
=
row
,
col
=
col
,
value
=
value
)
assert
not
mat
.
is_symmetric
()
mat
=
mat
.
to_symmetric
()
assert
mat
.
is_symmetric
()
assert
mat
.
to_dense
().
tolist
()
==
[
[
2
,
6
,
3
],
[
6
,
0
,
5
],
[
3
,
5
,
0
],
]
torch_sparse/storage.py
View file @
436b2e50
...
...
@@ -382,7 +382,6 @@ class SparseStorage(object):
ptr
=
mask
.
nonzero
().
flatten
()
ptr
=
torch
.
cat
([
ptr
,
ptr
.
new_full
((
1
,
),
value
.
size
(
0
))])
value
=
segment_csr
(
value
,
ptr
,
reduce
=
reduce
)
value
=
value
[
0
]
if
isinstance
(
value
,
tuple
)
else
value
return
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
None
,
...
...
torch_sparse/tensor.py
View file @
436b2e50
...
...
@@ -3,6 +3,7 @@ from typing import Optional, List, Tuple, Dict, Union, Any
import
torch
import
scipy.sparse
from
torch_scatter
import
segment_csr
from
torch_sparse.storage
import
SparseStorage
,
get_layout
...
...
@@ -270,17 +271,33 @@ class SparseTensor(object):
return
bool
((
value1
==
value2
).
all
())
def
to_symmetric
(
self
,
reduce
:
str
=
"sum"
):
N
=
max
(
self
.
size
(
0
),
self
.
size
(
1
))
row
,
col
,
value
=
self
.
coo
()
idx
=
col
.
new_full
((
2
*
col
.
numel
()
+
1
,
),
-
1
)
idx
[
1
:
row
.
numel
()
+
1
]
=
row
idx
[
row
.
numel
()
+
1
:]
=
col
idx
[
1
:]
*=
N
idx
[
1
:
row
.
numel
()
+
1
]
+=
col
idx
[
row
.
numel
()
+
1
:]
+=
row
idx
,
perm
=
idx
.
sort
()
perm
=
perm
[
1
:].
sub_
(
1
)
mask
=
idx
[
1
:]
>
idx
[:
-
1
]
idx2
=
perm
[
mask
]
row
,
col
=
torch
.
cat
([
row
,
col
],
dim
=
0
),
torch
.
cat
([
col
,
row
],
dim
=
0
)
if
value
is
not
None
:
value
=
torch
.
cat
([
value
,
value
],
dim
=
0
)
ptr
=
mask
.
nonzero
().
flatten
()
ptr
=
torch
.
cat
([
ptr
,
ptr
.
new_full
((
1
,
),
perm
.
size
(
0
))])
value
=
torch
.
cat
([
value
,
value
])[
perm
]
value
=
segment_csr
(
value
,
ptr
,
reduce
=
reduce
)
N
=
max
(
self
.
size
(
0
),
self
.
size
(
1
))
new_row
=
torch
.
cat
([
row
,
col
],
dim
=
0
,
out
=
perm
)[
idx2
]
new_col
=
torch
.
cat
([
col
,
row
],
dim
=
0
,
out
=
perm
)[
idx2
]
out
=
SparseTensor
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_sizes
=
(
N
,
N
),
is_sorted
=
False
)
out
=
out
.
coalesce
(
reduce
)
out
=
SparseTensor
(
row
=
new_row
,
rowptr
=
None
,
col
=
new_col
,
value
=
value
,
sparse_sizes
=
(
N
,
N
),
is_sorted
=
True
)
return
out
def
detach_
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment