Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
46dac04f
Commit
46dac04f
authored
Jan 19, 2021
by
rusty1s
Browse files
get diag
parent
105a60be
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
34 additions
and
3 deletions
+34
-3
test/test_diag.py
test/test_diag.py
+12
-0
torch_sparse/__init__.py
torch_sparse/__init__.py
+2
-1
torch_sparse/diag.py
torch_sparse/diag.py
+20
-2
No files found.
test/test_diag.py
View file @
46dac04f
...
@@ -52,3 +52,15 @@ def test_fill_diag(dtype, device):
...
@@ -52,3 +52,15 @@ def test_fill_diag(dtype, device):
mat
=
mat
.
fill_diag
(
-
8
,
k
=-
1
)
mat
=
mat
.
fill_diag
(
-
8
,
k
=-
1
)
mat
=
mat
.
fill_diag
(
-
8
,
k
=
1
)
mat
=
mat
.
fill_diag
(
-
8
,
k
=
1
)
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
def
test_get_diag
(
dtype
,
device
):
row
,
col
=
tensor
([[
0
,
0
,
1
,
2
],
[
0
,
1
,
2
,
2
]],
torch
.
long
,
device
)
value
=
tensor
([[
1
,
1
],
[
2
,
2
],
[
3
,
3
],
[
4
,
4
]],
dtype
,
device
)
mat
=
SparseTensor
(
row
=
row
,
col
=
col
,
value
=
value
)
assert
mat
.
get_diag
().
tolist
()
==
[[
1
,
1
],
[
0
,
0
],
[
4
,
4
]]
row
,
col
=
tensor
([[
0
,
0
,
1
,
2
],
[
0
,
1
,
2
,
2
]],
torch
.
long
,
device
)
mat
=
SparseTensor
(
row
=
row
,
col
=
col
)
assert
mat
.
get_diag
().
tolist
()
==
[
1
,
0
,
1
]
torch_sparse/__init__.py
View file @
46dac04f
...
@@ -39,7 +39,7 @@ from .select import select # noqa
...
@@ -39,7 +39,7 @@ from .select import select # noqa
from
.index_select
import
index_select
,
index_select_nnz
# noqa
from
.index_select
import
index_select
,
index_select_nnz
# noqa
from
.masked_select
import
masked_select
,
masked_select_nnz
# noqa
from
.masked_select
import
masked_select
,
masked_select_nnz
# noqa
from
.permute
import
permute
# noqa
from
.permute
import
permute
# noqa
from
.diag
import
remove_diag
,
set_diag
,
fill_diag
# noqa
from
.diag
import
remove_diag
,
set_diag
,
fill_diag
,
get_diag
# noqa
from
.add
import
add
,
add_
,
add_nnz
,
add_nnz_
# noqa
from
.add
import
add
,
add_
,
add_nnz
,
add_nnz_
# noqa
from
.mul
import
mul
,
mul_
,
mul_nnz
,
mul_nnz_
# noqa
from
.mul
import
mul
,
mul_
,
mul_nnz
,
mul_nnz_
# noqa
from
.reduce
import
sum
,
mean
,
min
,
max
# noqa
from
.reduce
import
sum
,
mean
,
min
,
max
# noqa
...
@@ -75,6 +75,7 @@ __all__ = [
...
@@ -75,6 +75,7 @@ __all__ = [
'remove_diag'
,
'remove_diag'
,
'set_diag'
,
'set_diag'
,
'fill_diag'
,
'fill_diag'
,
'get_diag'
,
'add'
,
'add'
,
'add_'
,
'add_'
,
'add_nnz'
,
'add_nnz'
,
...
...
torch_sparse/diag.py
View file @
46dac04f
from
typing
import
Optional
from
typing
import
Optional
import
torch
import
torch
from
torch
import
Tensor
from
torch_sparse.storage
import
SparseStorage
from
torch_sparse.storage
import
SparseStorage
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
...
@@ -31,7 +32,7 @@ def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
...
@@ -31,7 +32,7 @@ def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
return
src
.
from_storage
(
storage
)
return
src
.
from_storage
(
storage
)
def
set_diag
(
src
:
SparseTensor
,
values
:
Optional
[
torch
.
Tensor
]
=
None
,
def
set_diag
(
src
:
SparseTensor
,
values
:
Optional
[
Tensor
]
=
None
,
k
:
int
=
0
)
->
SparseTensor
:
k
:
int
=
0
)
->
SparseTensor
:
src
=
remove_diag
(
src
,
k
=
k
)
src
=
remove_diag
(
src
,
k
=
k
)
row
,
col
,
value
=
src
.
coo
()
row
,
col
,
value
=
src
.
coo
()
...
@@ -51,7 +52,7 @@ def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
...
@@ -51,7 +52,7 @@ def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
new_col
[
mask
]
=
col
new_col
[
mask
]
=
col
new_col
[
inv_mask
]
=
diag
.
add_
(
k
)
new_col
[
inv_mask
]
=
diag
.
add_
(
k
)
new_value
:
Optional
[
torch
.
Tensor
]
=
None
new_value
:
Optional
[
Tensor
]
=
None
if
value
is
not
None
:
if
value
is
not
None
:
new_value
=
value
.
new_empty
((
mask
.
size
(
0
),
)
+
value
.
size
()[
1
:])
new_value
=
value
.
new_empty
((
mask
.
size
(
0
),
)
+
value
.
size
()[
1
:])
new_value
[
mask
]
=
value
new_value
[
mask
]
=
value
...
@@ -92,8 +93,25 @@ def fill_diag(src: SparseTensor, fill_value: float,
...
@@ -92,8 +93,25 @@ def fill_diag(src: SparseTensor, fill_value: float,
return
set_diag
(
src
,
None
,
k
)
return
set_diag
(
src
,
None
,
k
)
def
get_diag
(
src
:
SparseTensor
)
->
Tensor
:
row
,
col
,
value
=
src
.
coo
()
if
value
is
None
:
value
=
torch
.
ones
(
row
.
size
(
0
))
sizes
=
list
(
value
.
size
())
sizes
[
0
]
=
min
(
src
.
size
(
0
),
src
.
size
(
1
))
out
=
value
.
new_zeros
(
sizes
)
mask
=
row
==
col
out
[
row
[
mask
]]
=
value
[
mask
]
return
out
SparseTensor
.
remove_diag
=
lambda
self
,
k
=
0
:
remove_diag
(
self
,
k
)
SparseTensor
.
remove_diag
=
lambda
self
,
k
=
0
:
remove_diag
(
self
,
k
)
SparseTensor
.
set_diag
=
lambda
self
,
values
=
None
,
k
=
0
:
set_diag
(
SparseTensor
.
set_diag
=
lambda
self
,
values
=
None
,
k
=
0
:
set_diag
(
self
,
values
,
k
)
self
,
values
,
k
)
SparseTensor
.
fill_diag
=
lambda
self
,
fill_value
,
k
=
0
:
fill_diag
(
SparseTensor
.
fill_diag
=
lambda
self
,
fill_value
,
k
=
0
:
fill_diag
(
self
,
fill_value
,
k
)
self
,
fill_value
,
k
)
SparseTensor
.
get_diag
=
lambda
self
:
get_diag
(
self
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment