Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
7636e1d1
Commit
7636e1d1
authored
Feb 03, 2020
by
rusty1s
Browse files
diag fix
parent
c86527dc
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
44 additions
and
18 deletions
+44
-18
test/test_diag.py
test/test_diag.py
+24
-14
torch_sparse/__init__.py
torch_sparse/__init__.py
+1
-1
torch_sparse/diag.py
torch_sparse/diag.py
+18
-2
torch_sparse/tensor.py
torch_sparse/tensor.py
+1
-1
No files found.
test/test_diag.py
View file @
7636e1d1
...
@@ -15,23 +15,23 @@ def test_remove_diag(dtype, device):
...
@@ -15,23 +15,23 @@ def test_remove_diag(dtype, device):
mat
.
fill_cache_
()
mat
.
fill_cache_
()
mat
=
mat
.
remove_diag
()
mat
=
mat
.
remove_diag
()
assert
mat
.
storage
.
row
.
tolist
()
==
[
0
,
1
]
assert
mat
.
storage
.
row
()
.
tolist
()
==
[
0
,
1
]
assert
mat
.
storage
.
col
.
tolist
()
==
[
1
,
2
]
assert
mat
.
storage
.
col
()
.
tolist
()
==
[
1
,
2
]
assert
mat
.
storage
.
value
.
tolist
()
==
[
2
,
3
]
assert
mat
.
storage
.
value
()
.
tolist
()
==
[
2
,
3
]
assert
len
(
mat
.
cached_keys
()
)
==
2
assert
mat
.
storage
.
num_
cached_keys
()
==
2
assert
mat
.
storage
.
rowcount
.
tolist
()
==
[
1
,
1
,
0
]
assert
mat
.
storage
.
rowcount
()
.
tolist
()
==
[
1
,
1
,
0
]
assert
mat
.
storage
.
colcount
.
tolist
()
==
[
0
,
1
,
1
]
assert
mat
.
storage
.
colcount
()
.
tolist
()
==
[
0
,
1
,
1
]
mat
=
SparseTensor
(
row
=
row
,
col
=
col
,
value
=
value
)
mat
=
SparseTensor
(
row
=
row
,
col
=
col
,
value
=
value
)
mat
.
fill_cache_
()
mat
.
fill_cache_
()
mat
=
mat
.
remove_diag
(
k
=
1
)
mat
=
mat
.
remove_diag
(
k
=
1
)
assert
mat
.
storage
.
row
.
tolist
()
==
[
0
,
2
]
assert
mat
.
storage
.
row
()
.
tolist
()
==
[
0
,
2
]
assert
mat
.
storage
.
col
.
tolist
()
==
[
0
,
2
]
assert
mat
.
storage
.
col
()
.
tolist
()
==
[
0
,
2
]
assert
mat
.
storage
.
value
.
tolist
()
==
[
1
,
4
]
assert
mat
.
storage
.
value
()
.
tolist
()
==
[
1
,
4
]
assert
len
(
mat
.
cached_keys
()
)
==
2
assert
mat
.
storage
.
num_
cached_keys
()
==
2
assert
mat
.
storage
.
rowcount
.
tolist
()
==
[
1
,
0
,
1
]
assert
mat
.
storage
.
rowcount
()
.
tolist
()
==
[
1
,
0
,
1
]
assert
mat
.
storage
.
colcount
.
tolist
()
==
[
1
,
0
,
1
]
assert
mat
.
storage
.
colcount
()
.
tolist
()
==
[
1
,
0
,
1
]
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
...
@@ -40,5 +40,15 @@ def test_set_diag(dtype, device):
...
@@ -40,5 +40,15 @@ def test_set_diag(dtype, device):
value
=
tensor
([
1
,
2
,
3
,
4
],
dtype
,
device
)
value
=
tensor
([
1
,
2
,
3
,
4
],
dtype
,
device
)
mat
=
SparseTensor
(
row
=
row
,
col
=
col
,
value
=
value
)
mat
=
SparseTensor
(
row
=
row
,
col
=
col
,
value
=
value
)
k
=
-
8
mat
=
mat
.
set_diag
(
tensor
([
-
8
,
-
8
],
dtype
,
device
),
k
=-
1
)
mat
=
mat
.
set_diag
(
k
)
mat
=
mat
.
set_diag
(
tensor
([
-
8
],
dtype
,
device
),
k
=
1
)
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
def
test_fill_diag
(
dtype
,
device
):
row
,
col
=
tensor
([[
0
,
0
,
9
,
9
],
[
0
,
1
,
0
,
1
]],
torch
.
long
,
device
)
value
=
tensor
([
1
,
2
,
3
,
4
],
dtype
,
device
)
mat
=
SparseTensor
(
row
=
row
,
col
=
col
,
value
=
value
)
mat
=
mat
.
fill_diag
(
-
8
,
k
=-
1
)
mat
=
mat
.
fill_diag
(
-
8
,
k
=
1
)
torch_sparse/__init__.py
View file @
7636e1d1
...
@@ -27,7 +27,7 @@ from .narrow import narrow
...
@@ -27,7 +27,7 @@ from .narrow import narrow
from
.select
import
select
from
.select
import
select
from
.index_select
import
index_select
,
index_select_nnz
from
.index_select
import
index_select
,
index_select_nnz
from
.masked_select
import
masked_select
,
masked_select_nnz
from
.masked_select
import
masked_select
,
masked_select_nnz
from
.diag
import
set_diag
,
remove
_diag
from
.diag
import
remove_diag
,
set_diag
,
fill
_diag
from
.add
import
add
,
add_
,
add_nnz
,
add_nnz_
from
.add
import
add
,
add_
,
add_nnz
,
add_nnz_
from
.mul
import
mul
,
mul_
,
mul_nnz
,
mul_nnz_
from
.mul
import
mul
,
mul_
,
mul_nnz
,
mul_nnz_
from
.reduce
import
sum
,
mean
,
min
,
max
from
.reduce
import
sum
,
mean
,
min
,
max
...
...
torch_sparse/diag.py
View file @
7636e1d1
...
@@ -50,7 +50,7 @@ def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
...
@@ -50,7 +50,7 @@ def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
@
torch
.
jit
.
script
@
torch
.
jit
.
script
def
set_diag
(
src
:
SparseTensor
,
values
:
Optional
[
torch
.
Tensor
]
=
None
,
def
set_diag
(
src
:
SparseTensor
,
values
:
Optional
[
torch
.
Tensor
]
=
None
,
k
:
int
=
0
)
->
SparseTensor
:
k
:
int
=
0
)
->
SparseTensor
:
src
=
remove_diag
(
src
,
k
=
0
)
src
=
remove_diag
(
src
,
k
=
k
)
row
,
col
,
value
=
src
.
coo
()
row
,
col
,
value
=
src
.
coo
()
mask
=
torch
.
ops
.
torch_sparse
.
non_diag_mask
(
row
,
col
,
src
.
size
(
0
),
mask
=
torch
.
ops
.
torch_sparse
.
non_diag_mask
(
row
,
col
,
src
.
size
(
0
),
...
@@ -65,7 +65,7 @@ def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
...
@@ -65,7 +65,7 @@ def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
new_row
[
inv_mask
]
=
diag
new_row
[
inv_mask
]
=
diag
new_col
=
col
.
new_empty
(
mask
.
size
(
0
))
new_col
=
col
.
new_empty
(
mask
.
size
(
0
))
new_col
[
mask
]
=
row
new_col
[
mask
]
=
col
new_col
[
inv_mask
]
=
diag
.
add_
(
k
)
new_col
[
inv_mask
]
=
diag
.
add_
(
k
)
new_value
:
Optional
[
torch
.
Tensor
]
=
None
new_value
:
Optional
[
torch
.
Tensor
]
=
None
...
@@ -95,6 +95,22 @@ def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
...
@@ -95,6 +95,22 @@ def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
return
src
.
from_storage
(
storage
)
return
src
.
from_storage
(
storage
)
@
torch
.
jit
.
script
def
fill_diag
(
src
:
SparseTensor
,
fill_value
:
int
,
k
:
int
=
0
)
->
SparseTensor
:
num_diag
=
min
(
src
.
sparse_size
(
0
),
src
.
sparse_size
(
1
)
-
k
)
if
k
<
0
:
num_diag
=
min
(
src
.
sparse_size
(
0
)
+
k
,
src
.
sparse_size
(
1
))
value
=
src
.
storage
.
value
()
if
value
is
not
None
:
sizes
=
[
num_diag
]
+
src
.
sizes
()[
2
:]
return
set_diag
(
src
,
value
.
new_full
(
sizes
,
fill_value
),
k
)
else
:
return
set_diag
(
src
,
None
,
k
)
SparseTensor
.
remove_diag
=
lambda
self
,
k
=
0
:
remove_diag
(
self
,
k
)
SparseTensor
.
remove_diag
=
lambda
self
,
k
=
0
:
remove_diag
(
self
,
k
)
SparseTensor
.
set_diag
=
lambda
self
,
values
=
None
,
k
=
0
:
set_diag
(
SparseTensor
.
set_diag
=
lambda
self
,
values
=
None
,
k
=
0
:
set_diag
(
self
,
values
,
k
)
self
,
values
,
k
)
SparseTensor
.
fill_diag
=
lambda
self
,
fill_value
,
k
=
0
:
fill_diag
(
self
,
fill_value
,
k
)
torch_sparse/tensor.py
View file @
7636e1d1
...
@@ -197,7 +197,7 @@ class SparseTensor(object):
...
@@ -197,7 +197,7 @@ class SparseTensor(object):
sizes
=
self
.
sparse_sizes
()
sizes
=
self
.
sparse_sizes
()
value
=
self
.
storage
.
value
()
value
=
self
.
storage
.
value
()
if
value
is
not
None
:
if
value
is
not
None
:
sizes
=
sizes
+
value
.
size
()[
1
:]
sizes
=
list
(
sizes
)
+
list
(
value
.
size
()
)
[
1
:]
return
sizes
return
sizes
def
size
(
self
,
dim
:
int
)
->
int
:
def
size
(
self
,
dim
:
int
)
->
int
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment