Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
4242e343
Commit
4242e343
authored
Jan 20, 2020
by
rusty1s
Browse files
eye implementation
parent
fc183212
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
41 additions
and
3 deletions
+41
-3
test/test_diag.py
test/test_diag.py
+1
-1
torch_sparse/diag.py
torch_sparse/diag.py
+5
-2
torch_sparse/tensor.py
torch_sparse/tensor.py
+35
-0
No files found.
test/test_diag.py
View file @
4242e343
...
@@ -52,7 +52,7 @@ def test_set_diag(dtype, device):
...
@@ -52,7 +52,7 @@ def test_set_diag(dtype, device):
print
()
print
()
k
=
-
8
k
=
-
8
print
(
"k = "
,
k
)
print
(
"k = "
,
k
)
mat
=
mat
.
remove
_diag
(
k
)
mat
=
mat
.
set
_diag
(
k
)
print
(
mat
.
to_dense
())
print
(
mat
.
to_dense
())
# row, col = mat.storage.index
# row, col = mat.storage.index
...
...
torch_sparse/diag.py
View file @
4242e343
...
@@ -39,7 +39,10 @@ def remove_diag(src, k=0):
...
@@ -39,7 +39,10 @@ def remove_diag(src, k=0):
return
src
.
__class__
.
from_storage
(
storage
)
return
src
.
__class__
.
from_storage
(
storage
)
def
set_diag
(
src
,
value
=
None
,
k
=
0
):
def
set_diag
(
src
,
values
=
None
,
k
=
0
):
if
values
is
not
None
and
not
src
.
has_value
():
raise
ValueError
(
'Sparse matrix has no values'
)
src
=
src
.
remove_diag
(
k
=
0
)
src
=
src
.
remove_diag
(
k
=
0
)
index
,
value
=
src
.
coo
()
index
,
value
=
src
.
coo
()
...
@@ -63,7 +66,7 @@ def set_diag(src, value=None, k=0):
...
@@ -63,7 +66,7 @@ def set_diag(src, value=None, k=0):
if
src
.
has_value
():
if
src
.
has_value
():
new_value
=
torch
.
new_empty
((
mask
.
size
(
0
),
)
+
mask
.
size
()[
1
:])
new_value
=
torch
.
new_empty
((
mask
.
size
(
0
),
)
+
mask
.
size
()[
1
:])
new_value
[
mask
]
=
value
new_value
[
mask
]
=
value
new_value
[
inv_mask
]
=
1
new_value
[
inv_mask
]
=
values
if
values
is
not
None
else
1
rowcount
=
None
rowcount
=
None
if
src
.
storage
.
has_rowcount
():
if
src
.
storage
.
has_rowcount
():
...
...
torch_sparse/tensor.py
View file @
4242e343
...
@@ -61,6 +61,41 @@ class SparseTensor(object):
...
@@ -61,6 +61,41 @@ class SparseTensor(object):
return
SparseTensor
.
from_storage
(
storage
)
return
SparseTensor
.
from_storage
(
storage
)
@
classmethod
def
eye
(
self
,
m
,
n
=
None
,
device
=
None
,
no_value
=
True
,
fill_cache
=
False
):
n
=
m
if
n
is
None
else
n
index
=
torch
.
empty
((
2
,
min
(
m
,
n
)),
dtype
=
torch
.
long
,
device
=
device
)
torch
.
arange
(
index
.
size
(
1
),
out
=
index
[
0
])
torch
.
arange
(
index
.
size
(
1
),
out
=
index
[
1
])
value
=
None
if
not
no_value
:
value
=
torch
.
ones
(
index
.
size
(
1
),
device
=
device
)
rowcount
=
rowptr
=
colcount
=
colptr
=
csr2csc
=
csc2csr
=
None
if
fill_cache
:
rowcount
=
index
.
new_ones
(
m
)
rowptr
=
torch
.
arange
(
m
+
1
,
device
=
device
)
colcount
=
index
.
new_ones
(
n
)
colptr
=
torch
.
arange
(
n
+
1
,
device
=
device
)
csr2csc
=
torch
.
arange
(
index
.
size
(
1
),
device
=
device
)
csc2csr
=
torch
.
arange
(
index
.
size
(
1
),
device
=
device
)
storage
=
SparseStorage
(
index
,
value
,
torch
.
Size
([
m
,
n
]),
rowcount
=
rowcount
,
rowptr
=
rowptr
,
colcount
=
colcount
,
colptr
=
colptr
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
,
)
return
SparseTensor
.
from_storage
(
storage
)
def
__copy__
(
self
):
def
__copy__
(
self
):
return
self
.
from_storage
(
self
.
storage
)
return
self
.
from_storage
(
self
.
storage
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment