Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
6b013fbe
Commit
6b013fbe
authored
Feb 04, 2020
by
rusty1s
Browse files
narrow_diag
parent
08ceec29
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
63 additions
and
3 deletions
+63
-3
torch_sparse/__init__.py
torch_sparse/__init__.py
+2
-1
torch_sparse/narrow.py
torch_sparse/narrow.py
+61
-2
No files found.
torch_sparse/__init__.py
View file @
6b013fbe
from
.storage
import
SparseStorage
from
.storage
import
SparseStorage
from
.tensor
import
SparseTensor
from
.tensor
import
SparseTensor
from
.transpose
import
t
from
.transpose
import
t
from
.narrow
import
narrow
from
.narrow
import
narrow
,
__narrow_diag__
from
.select
import
select
from
.select
import
select
from
.index_select
import
index_select
,
index_select_nnz
from
.index_select
import
index_select
,
index_select_nnz
from
.masked_select
import
masked_select
,
masked_select_nnz
from
.masked_select
import
masked_select
,
masked_select_nnz
...
@@ -26,6 +26,7 @@ __all__ = [
...
@@ -26,6 +26,7 @@ __all__ = [
'SparseTensor'
,
'SparseTensor'
,
't'
,
't'
,
'narrow'
,
'narrow'
,
'__narrow_diag__'
,
'select'
,
'select'
,
'index_select'
,
'index_select'
,
'index_select_nnz'
,
'index_select_nnz'
,
...
...
torch_sparse/narrow.py
View file @
6b013fbe
from
typing
import
Tuple
import
torch
import
torch
from
torch_sparse.storage
import
SparseStorage
from
torch_sparse.storage
import
SparseStorage
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
...
@@ -6,8 +8,11 @@ from torch_sparse.tensor import SparseTensor
...
@@ -6,8 +8,11 @@ from torch_sparse.tensor import SparseTensor
@
torch
.
jit
.
script
@
torch
.
jit
.
script
def
narrow
(
src
:
SparseTensor
,
dim
:
int
,
start
:
int
,
def
narrow
(
src
:
SparseTensor
,
dim
:
int
,
start
:
int
,
length
:
int
)
->
SparseTensor
:
length
:
int
)
->
SparseTensor
:
dim
=
src
.
dim
()
+
dim
if
dim
<
0
else
dim
if
dim
<
0
:
start
=
src
.
size
(
dim
)
+
start
if
start
<
0
else
start
dim
=
src
.
dim
()
+
dim
if
start
<
0
:
start
=
src
.
size
(
dim
)
+
start
if
dim
==
0
:
if
dim
==
0
:
rowptr
,
col
,
value
=
src
.
csr
()
rowptr
,
col
,
value
=
src
.
csr
()
...
@@ -75,5 +80,59 @@ def narrow(src: SparseTensor, dim: int, start: int,
...
@@ -75,5 +80,59 @@ def narrow(src: SparseTensor, dim: int, start: int,
raise
ValueError
raise
ValueError
@
torch
.
jit
.
script
def
__narrow_diag__
(
src
:
SparseTensor
,
start
:
Tuple
[
int
,
int
],
length
:
Tuple
[
int
,
int
])
->
SparseTensor
:
# This function builds the inverse operation of `cat_diag` and should hence
# only be used on *diagonally stacked* sparse matrices.
rowptr
,
col
,
value
=
src
.
csr
()
rowptr
=
rowptr
.
narrow
(
0
,
start
=
start
[
0
],
length
=
length
[
0
]
+
1
)
row_start
=
rowptr
[
0
]
rowptr
=
rowptr
-
row_start
row_length
=
rowptr
[
-
1
]
row
=
src
.
storage
.
_row
if
row
is
not
None
:
row
=
row
.
narrow
(
0
,
row_start
,
row_length
)
-
start
[
0
]
col
=
col
.
narrow
(
0
,
row_start
,
row_length
)
-
start
[
1
]
if
value
is
not
None
:
value
=
value
.
narrow
(
0
,
row_start
,
row_length
)
sparse_sizes
=
length
rowcount
=
src
.
storage
.
_rowcount
if
rowcount
is
not
None
:
rowcount
=
rowcount
.
narrow
(
0
,
start
[
0
],
length
[
0
])
colptr
=
src
.
storage
.
_colptr
if
colptr
is
not
None
:
colptr
=
colptr
.
narrow
(
0
,
start
[
1
],
length
[
1
]
+
1
)
colptr
=
colptr
-
colptr
[
0
]
# i.e. `row_start`
colcount
=
src
.
storage
.
_colcount
if
colcount
is
not
None
:
colcount
=
colcount
.
narrow
(
0
,
start
[
1
],
length
[
1
])
csr2csc
=
src
.
storage
.
_csr2csc
if
csr2csc
is
not
None
:
csr2csc
=
csr2csc
.
narrow
(
0
,
row_start
,
row_length
)
-
row_start
csc2csr
=
src
.
storage
.
_csc2csr
if
csc2csr
is
not
None
:
csc2csr
=
csc2csr
.
narrow
(
0
,
row_start
,
row_length
)
-
row_start
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
return
src
.
from_storage
(
storage
)
SparseTensor
.
narrow
=
lambda
self
,
dim
,
start
,
length
:
narrow
(
SparseTensor
.
narrow
=
lambda
self
,
dim
,
start
,
length
:
narrow
(
self
,
dim
,
start
,
length
)
self
,
dim
,
start
,
length
)
SparseTensor
.
__narrow_diag__
=
lambda
self
,
start
,
length
:
__narrow_diag__
(
self
,
start
,
length
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment