Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
bc184b1a
Commit
bc184b1a
authored
Feb 07, 2020
by
rusty1s
Browse files
slightly faster narrow
parent
0fd9cfe2
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
6 deletions
+8
-6
torch_sparse/narrow.py
torch_sparse/narrow.py
+5
-3
torch_sparse/storage.py
torch_sparse/storage.py
+3
-3
No files found.
torch_sparse/narrow.py
View file @
bc184b1a
import
copy
from
typing
import
Tuple
from
typing
import
Tuple
import
torch
import
torch
...
@@ -85,13 +86,14 @@ def __narrow_diag__(src: SparseTensor, start: Tuple[int, int],
...
@@ -85,13 +86,14 @@ def __narrow_diag__(src: SparseTensor, start: Tuple[int, int],
length
:
Tuple
[
int
,
int
])
->
SparseTensor
:
length
:
Tuple
[
int
,
int
])
->
SparseTensor
:
# This function builds the inverse operation of `cat_diag` and should hence
# This function builds the inverse operation of `cat_diag` and should hence
# only be used on *diagonally stacked* sparse matrices.
# only be used on *diagonally stacked* sparse matrices.
# That's the reason why this method is marked as *private*.
rowptr
,
col
,
value
=
src
.
csr
()
rowptr
,
col
,
value
=
src
.
csr
()
rowptr
=
rowptr
.
narrow
(
0
,
start
=
start
[
0
],
length
=
length
[
0
]
+
1
)
rowptr
=
rowptr
.
narrow
(
0
,
start
=
start
[
0
],
length
=
length
[
0
]
+
1
)
row_start
=
rowptr
[
0
]
row_start
=
int
(
rowptr
[
0
]
)
rowptr
=
rowptr
-
row_start
rowptr
=
rowptr
-
row_start
row_length
=
rowptr
[
-
1
]
row_length
=
int
(
rowptr
[
-
1
]
)
row
=
src
.
storage
.
_row
row
=
src
.
storage
.
_row
if
row
is
not
None
:
if
row
is
not
None
:
...
@@ -111,7 +113,7 @@ def __narrow_diag__(src: SparseTensor, start: Tuple[int, int],
...
@@ -111,7 +113,7 @@ def __narrow_diag__(src: SparseTensor, start: Tuple[int, int],
colptr
=
src
.
storage
.
_colptr
colptr
=
src
.
storage
.
_colptr
if
colptr
is
not
None
:
if
colptr
is
not
None
:
colptr
=
colptr
.
narrow
(
0
,
start
[
1
],
length
[
1
]
+
1
)
colptr
=
colptr
.
narrow
(
0
,
start
[
1
],
length
[
1
]
+
1
)
colptr
=
colptr
-
colptr
[
0
]
# i.e. `row_start`
colptr
=
colptr
-
int
(
colptr
[
0
]
)
# i.e. `row_start`
colcount
=
src
.
storage
.
_colcount
colcount
=
src
.
storage
.
_colcount
if
colcount
is
not
None
:
if
colcount
is
not
None
:
...
...
torch_sparse/storage.py
View file @
bc184b1a
...
@@ -144,12 +144,12 @@ class SparseStorage(object):
...
@@ -144,12 +144,12 @@ class SparseStorage(object):
self
.
_csc2csr
=
csc2csr
self
.
_csc2csr
=
csc2csr
if
not
is_sorted
:
if
not
is_sorted
:
idx
=
col
.
new_zeros
(
col
.
numel
()
+
1
)
idx
=
self
.
_
col
.
new_zeros
(
self
.
_
col
.
numel
()
+
1
)
idx
[
1
:]
=
sparse_sizes
[
1
]
*
self
.
row
()
+
col
idx
[
1
:]
=
self
.
_
sparse_sizes
[
1
]
*
self
.
row
()
+
self
.
_
col
if
(
idx
[
1
:]
<
idx
[:
-
1
]).
any
():
if
(
idx
[
1
:]
<
idx
[:
-
1
]).
any
():
perm
=
idx
[
1
:].
argsort
()
perm
=
idx
[
1
:].
argsort
()
self
.
_row
=
self
.
row
()[
perm
]
self
.
_row
=
self
.
row
()[
perm
]
self
.
_col
=
col
[
perm
]
self
.
_col
=
self
.
_
col
[
perm
]
if
value
is
not
None
:
if
value
is
not
None
:
self
.
_value
=
value
[
perm
]
self
.
_value
=
value
[
perm
]
self
.
_csr2csc
=
None
self
.
_csr2csc
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment