Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
411dcb76
Commit
411dcb76
authored
Jan 28, 2020
by
rusty1s
Browse files
select methods
parent
7f7036cd
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
81 additions
and
59 deletions
+81
-59
torch_sparse/__init__.py
torch_sparse/__init__.py
+2
-0
torch_sparse/narrow.py
torch_sparse/narrow.py
+39
-28
torch_sparse/select.py
torch_sparse/select.py
+11
-2
torch_sparse/tensor.py
torch_sparse/tensor.py
+1
-3
torch_sparse/transpose.py
torch_sparse/transpose.py
+28
-26
No files found.
torch_sparse/__init__.py
View file @
411dcb76
...
...
@@ -23,3 +23,5 @@ __all__ = [
from
.storage
import
SparseStorage
from
.tensor
import
SparseTensor
from
.transpose
import
t
from
.narrow
import
narrow
from
.select
import
select
torch_sparse/narrow.py
View file @
411dcb76
import
torch
from
torch_sparse.storage
import
SparseStorage
from
torch_sparse.tensor
import
SparseTensor
def
narrow
(
src
,
dim
,
start
,
length
):
@
torch
.
jit
.
script
def
narrow
(
src
:
SparseTensor
,
dim
:
int
,
start
:
int
,
length
:
int
):
dim
=
src
.
dim
()
+
dim
if
dim
<
0
else
dim
start
=
src
.
size
(
dim
)
+
start
if
start
<
0
else
start
if
dim
==
0
:
rowptr
,
col
,
value
=
src
.
csr
()
# rowptr = src.storage.rowptr
# Maintain `rowcount`...
rowcount
=
src
.
storage
.
_rowcount
if
rowcount
is
not
None
:
rowcount
=
rowcount
.
narrow
(
0
,
start
=
start
,
length
=
length
)
rowptr
=
rowptr
.
narrow
(
0
,
start
=
start
,
length
=
length
+
1
)
row_start
=
rowptr
[
0
]
...
...
@@ -22,46 +19,60 @@ def narrow(src, dim, start, length):
row
=
src
.
storage
.
_row
if
row
is
not
None
:
row
=
row
.
narrow
(
0
,
row_start
,
row_length
)
-
start
col
=
col
.
narrow
(
0
,
row_start
,
row_length
)
if
src
.
has_value
()
:
if
value
is
not
None
:
value
=
value
.
narrow
(
0
,
row_start
,
row_length
)
sparse_size
=
torch
.
Size
([
length
,
src
.
sparse_size
(
1
)])
sparse_size
s
=
torch
.
Size
([
length
,
src
.
sparse_size
(
1
)])
storage
=
src
.
storage
.
__class__
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_size
=
sparse_size
,
rowcount
=
rowcount
,
is_sorted
=
True
)
rowcount
=
src
.
storage
.
_rowcount
if
rowcount
is
not
None
:
rowcount
=
rowcount
.
narrow
(
0
,
start
=
start
,
length
=
length
)
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
return
src
.
from_storage
(
storage
)
elif
dim
==
1
:
# This is faster than accessing `csc()` contrary to the `dim=0` case.
row
,
col
,
value
=
src
.
coo
()
mask
=
(
col
>=
start
)
&
(
col
<
start
+
length
)
row
,
col
=
row
[
mask
],
col
[
mask
]
-
start
row
=
row
[
mask
]
col
=
col
[
mask
]
-
start
# Maintain `colcount`...
colcount
=
src
.
storage
.
_colcount
if
colcount
is
not
None
:
colcount
=
colcount
.
narrow
(
0
,
start
=
start
,
length
=
length
)
if
value
is
not
None
:
value
=
value
[
mask
]
sparse_sizes
=
torch
.
Size
([
src
.
sparse_size
(
0
)
,
length
]
)
# Maintain `colptr`...
colptr
=
src
.
storage
.
_colptr
if
colptr
is
not
None
:
colptr
=
colptr
.
narrow
(
0
,
start
=
start
,
length
=
length
+
1
)
colptr
=
colptr
-
colptr
[
0
]
if
src
.
has_value
():
value
=
value
[
mask
]
sparse_size
=
torch
.
Size
([
src
.
sparse_size
(
0
),
length
])
colcount
=
src
.
storage
.
_colcount
if
colcount
is
not
None
:
colcount
=
colcount
.
narrow
(
0
,
start
=
start
,
length
=
length
)
storage
=
src
.
storage
.
__class__
(
row
=
row
,
col
=
col
,
value
=
value
,
sparse_size
=
sparse_size
,
colptr
=
colptr
,
colcount
=
colcount
,
is_sorted
=
True
)
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
None
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
return
src
.
from_storage
(
storage
)
else
:
storage
=
src
.
storage
.
apply_value
(
lambda
x
:
x
.
narrow
(
dim
-
1
,
start
,
length
))
value
=
src
.
storage
.
value
()
if
value
is
not
None
:
return
src
.
set_value
(
value
.
narrow
(
dim
-
1
,
start
,
length
),
layout
=
'coo'
)
else
:
raise
ValueError
return
src
.
from_storage
(
storage
)
SparseTensor
.
narrow
=
lambda
self
,
dim
,
start
,
length
:
narrow
(
self
,
dim
,
start
,
length
)
torch_sparse/select.py
View file @
411dcb76
def
select
(
src
,
dim
,
idx
):
return
src
.
narrow
(
dim
,
start
=
idx
,
length
=
1
)
import
torch
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.narrow
import
narrow
@
torch
.
jit
.
script
def
select
(
src
:
SparseTensor
,
dim
:
int
,
idx
:
int
):
return
narrow
(
src
,
dim
,
start
=
idx
,
length
=
1
)
SparseTensor
.
select
=
lambda
self
,
dim
,
idx
:
select
(
self
,
dim
,
idx
)
torch_sparse/tensor.py
View file @
411dcb76
...
...
@@ -6,12 +6,10 @@ import scipy.sparse
from
torch_sparse.storage
import
SparseStorage
,
get_layout
# from torch_sparse.narrow import narrow
# from torch_sparse.select import select
# from torch_sparse.index_select import index_select, index_select_nnz
# from torch_sparse.masked_select import masked_select, masked_select_nnz
# import torch_sparse.reduce
# from torch_sparse.diag import remove_diag, set_diag
# import torch_sparse.reduce
# from torch_sparse.matmul import matmul
# from torch_sparse.add import add, add_, add_nnz, add_nnz_
# from torch_sparse.mul import mul, mul_, mul_nnz, mul_nnz_
...
...
torch_sparse/transpose.py
View file @
411dcb76
...
...
@@ -5,32 +5,6 @@ from torch_sparse.storage import SparseStorage
from
torch_sparse.tensor
import
SparseTensor
def
transpose
(
index
,
value
,
m
,
n
,
coalesced
=
True
):
"""Transposes dimensions 0 and 1 of a sparse tensor.
Args:
index (:class:`LongTensor`): The index tensor of sparse matrix.
value (:class:`Tensor`): The value tensor of sparse matrix.
m (int): The first dimension of corresponding dense matrix.
n (int): The second dimension of corresponding dense matrix.
coalesced (bool, optional): If set to :obj:`False`, will not coalesce
the output. (default: :obj:`True`)
:rtype: (:class:`LongTensor`, :class:`Tensor`)
"""
if
value
.
dim
()
==
1
and
not
value
.
is_cuda
:
mat
=
to_scipy
(
index
,
value
,
m
,
n
).
tocsc
()
(
col
,
row
),
value
=
from_scipy
(
mat
)
index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
return
index
,
value
row
,
col
=
index
index
=
torch
.
stack
([
col
,
row
],
dim
=
0
)
if
coalesced
:
index
,
value
=
coalesce
(
index
,
value
,
n
,
m
)
return
index
,
value
@
torch
.
jit
.
script
def
t
(
src
:
SparseTensor
):
csr2csc
=
src
.
storage
.
csr2csc
()
...
...
@@ -60,3 +34,31 @@ def t(src: SparseTensor):
SparseTensor
.
t
=
lambda
self
:
t
(
self
)
###############################################################################
def
transpose
(
index
,
value
,
m
,
n
,
coalesced
=
True
):
"""Transposes dimensions 0 and 1 of a sparse tensor.
Args:
index (:class:`LongTensor`): The index tensor of sparse matrix.
value (:class:`Tensor`): The value tensor of sparse matrix.
m (int): The first dimension of corresponding dense matrix.
n (int): The second dimension of corresponding dense matrix.
coalesced (bool, optional): If set to :obj:`False`, will not coalesce
the output. (default: :obj:`True`)
:rtype: (:class:`LongTensor`, :class:`Tensor`)
"""
if
value
.
dim
()
==
1
and
not
value
.
is_cuda
:
mat
=
to_scipy
(
index
,
value
,
m
,
n
).
tocsc
()
(
col
,
row
),
value
=
from_scipy
(
mat
)
index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
return
index
,
value
row
,
col
=
index
index
=
torch
.
stack
([
col
,
row
],
dim
=
0
)
if
coalesced
:
index
,
value
=
coalesce
(
index
,
value
,
n
,
m
)
return
index
,
value
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment