Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
4a68dd60
Commit
4a68dd60
authored
Jan 29, 2020
by
rusty1s
Browse files
add mul diag traceable
parent
f6bf81df
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
280 additions
and
331 deletions
+280
-331
torch_sparse/__init__.py
torch_sparse/__init__.py
+5
-0
torch_sparse/add.py
torch_sparse/add.py
+55
-114
torch_sparse/diag.py
torch_sparse/diag.py
+52
-43
torch_sparse/index_select.py
torch_sparse/index_select.py
+48
-33
torch_sparse/masked_select.py
torch_sparse/masked_select.py
+49
-33
torch_sparse/mul.py
torch_sparse/mul.py
+67
-83
torch_sparse/narrow.py
torch_sparse/narrow.py
+2
-1
torch_sparse/select.py
torch_sparse/select.py
+1
-1
torch_sparse/tensor.py
torch_sparse/tensor.py
+0
-22
torch_sparse/transpose.py
torch_sparse/transpose.py
+1
-1
No files found.
torch_sparse/__init__.py
View file @
4a68dd60
...
@@ -40,3 +40,8 @@ from .tensor import SparseTensor
...
@@ -40,3 +40,8 @@ from .tensor import SparseTensor
from
.transpose
import
t
from
.transpose
import
t
from
.narrow
import
narrow
from
.narrow
import
narrow
from
.select
import
select
from
.select
import
select
from
.index_select
import
index_select
,
index_select_nnz
from
.masked_select
import
masked_select
,
masked_select_nnz
from
.diag
import
set_diag
,
remove_diag
from
.add
import
add
,
add_
,
add_nnz
,
add_nnz_
from
.mul
import
mul
,
mul_
,
mul_nnz
,
mul_nnz_
torch_sparse/add.py
View file @
4a68dd60
from
typing
import
Optional
import
torch
import
torch
from
torch_scatter
import
gather_csr
from
torch_scatter
import
gather_csr
from
torch_sparse.utils
import
is_scalar
from
torch_sparse.tensor
import
SparseTensor
def
sparse_add
(
matA
,
matB
):
nnzA
,
nnzB
=
matA
.
nnz
(),
matB
.
nnz
()
valA
=
torch
.
full
((
nnzA
,
),
1
,
dtype
=
torch
.
uint8
,
device
=
matA
.
device
)
valB
=
torch
.
full
((
nnzB
,
),
2
,
dtype
=
torch
.
uint8
,
device
=
matB
.
device
)
if
matA
.
is_cuda
:
pass
else
:
matA_
=
matA
.
set_value
(
valA
,
layout
=
'csr'
).
to_scipy
(
layout
=
'csr'
)
matB_
=
matB
.
set_value
(
valB
,
layout
=
'csr'
).
to_scipy
(
layout
=
'csr'
)
matC_
=
matA_
+
matB_
rowptr
=
torch
.
from_numpy
(
matC_
.
indptr
).
to
(
torch
.
long
)
matC_
=
matC_
.
tocoo
()
row
=
torch
.
from_numpy
(
matC_
.
row
).
to
(
torch
.
long
)
col
=
torch
.
from_numpy
(
matC_
.
col
).
to
(
torch
.
long
)
index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
valC_
=
torch
.
from_numpy
(
matC_
.
data
)
value
=
None
if
matA
.
has_value
()
or
matB
.
has_value
():
maskA
,
maskB
=
valC_
!=
2
,
valC_
>=
2
size
=
matA
.
size
()
if
matA
.
dim
()
>=
matB
.
dim
()
else
matA
.
size
()
size
=
(
valC_
.
size
(
0
),
)
+
size
[
2
:]
value
=
torch
.
zeros
(
size
,
dtype
=
matA
.
dtype
,
device
=
matA
.
device
)
value
[
maskA
]
+=
matA
.
storage
.
value
if
matA
.
has_value
()
else
1
value
[
maskB
]
+=
matB
.
storage
.
value
if
matB
.
has_value
()
else
1
storage
=
matA
.
storage
.
__class__
(
index
,
value
,
matA
.
sparse_size
(),
rowptr
=
rowptr
,
is_sorted
=
True
)
return
matA
.
__class__
.
from_storage
(
storage
)
def
add
(
src
,
other
):
@
torch
.
jit
.
script
if
is_scalar
(
other
):
def
add
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
SparseTensor
:
return
add_nnz
(
src
,
other
)
elif
torch
.
is_tensor
(
other
):
rowptr
,
col
,
value
=
src
.
csr
()
rowptr
,
col
,
value
=
src
.
csr
()
if
other
.
size
(
0
)
==
src
.
size
(
0
)
and
other
.
size
(
1
)
==
1
:
# Row-wise...
if
other
.
size
(
0
)
==
src
.
size
(
0
)
and
other
.
size
(
1
)
==
1
:
# Row-wise...
other
=
gather_csr
(
other
.
squeeze
(
1
),
rowptr
)
# TODO
value
=
other
.
add_
(
src
.
storage
.
value
if
src
.
has_value
()
else
1
)
# other = gather_csr(other.squeeze(1), rowptr)
return
src
.
set_value
(
value
,
layout
=
'csr'
)
pass
elif
other
.
size
(
0
)
==
1
and
other
.
size
(
1
)
==
src
.
size
(
1
):
# Col-wise...
if
other
.
size
(
0
)
==
1
and
other
.
size
(
1
)
==
src
.
size
(
1
):
# Col-wise...
other
=
other
.
squeeze
(
0
)[
col
]
other
=
other
.
squeeze
(
0
)[
col
]
value
=
other
.
add_
(
src
.
storage
.
value
if
src
.
has_value
()
else
1
)
else
:
return
src
.
set_value
(
value
,
layout
=
'coo'
)
raise
ValueError
(
f
'Size mismatch: Expected size (
{
src
.
size
(
0
)
}
, 1,'
raise
ValueError
(
f
'Size mismatch: Expected size (
{
src
.
size
(
0
)
}
, 1,'
f
' ...) or (1,
{
src
.
size
(
1
)
}
, ...), but got size '
f
' ...) or (1,
{
src
.
size
(
1
)
}
, ...), but got size '
f
'
{
other
.
size
()
}
.'
)
f
'
{
other
.
size
()
}
.'
)
elif
isinstance
(
other
,
src
.
__class__
):
if
value
is
not
None
:
raise
NotImplementedError
value
=
other
.
add_
(
value
)
else
:
raise
ValueError
(
'Argument `other` needs to be of type `int`, `float`, '
value
=
other
.
add_
(
1
)
'`torch.tensor` or `torch_sparse.SparseTensor`.'
)
return
src
.
set_value
(
value
,
layout
=
'coo'
)
def
add_
(
src
,
other
):
if
is_scalar
(
other
):
return
add_nnz_
(
src
,
other
)
elif
torch
.
is_tensor
(
other
):
@
torch
.
jit
.
script
def
add_
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
SparseTensor
:
rowptr
,
col
,
value
=
src
.
csr
()
rowptr
,
col
,
value
=
src
.
csr
()
if
other
.
size
(
0
)
==
src
.
size
(
0
)
and
other
.
size
(
1
)
==
1
:
# Row-wise...
if
other
.
size
(
0
)
==
src
.
size
(
0
)
and
other
.
size
(
1
)
==
1
:
# Row-wise...
other
=
gather_csr
(
other
.
squeeze
(
1
),
rowptr
)
# TODO
if
src
.
has_value
():
# other = gather_csr(other.squeeze(1), rowptr)
value
=
src
.
storage
.
value
.
add_
(
other
)
pass
else
:
elif
other
.
size
(
0
)
==
1
and
other
.
size
(
1
)
==
src
.
size
(
1
):
# Col-wise...
value
=
other
.
add_
(
1
)
return
src
.
set_value_
(
value
,
layout
=
'csr'
)
if
other
.
size
(
0
)
==
1
and
other
.
size
(
1
)
==
src
.
size
(
1
):
# Col-wise...
other
=
other
.
squeeze
(
0
)[
col
]
other
=
other
.
squeeze
(
0
)[
col
]
if
src
.
has_value
():
value
=
src
.
storage
.
value
.
add_
(
other
)
else
:
else
:
value
=
other
.
add_
(
1
)
return
src
.
set_value_
(
value
,
layout
=
'coo'
)
raise
ValueError
(
f
'Size mismatch: Expected size (
{
src
.
size
(
0
)
}
, 1,'
raise
ValueError
(
f
'Size mismatch: Expected size (
{
src
.
size
(
0
)
}
, 1,'
f
' ...) or (1,
{
src
.
size
(
1
)
}
, ...), but got size '
f
' ...) or (1,
{
src
.
size
(
1
)
}
, ...), but got size '
f
'
{
other
.
size
()
}
.'
)
f
'
{
other
.
size
()
}
.'
)
elif
isinstance
(
other
,
src
.
__class__
):
if
value
is
not
None
:
raise
NotImplementedError
value
=
value
.
add_
(
other
)
raise
ValueError
(
'Argument `other` needs to be of type `int`, `float`, '
'`torch.tensor` or `torch_sparse.SparseTensor`.'
)
def
add_nnz
(
src
,
other
,
layout
=
None
):
if
is_scalar
(
other
):
if
src
.
has_value
():
value
=
src
.
storage
.
value
+
other
else
:
else
:
value
=
torch
.
full
((
src
.
nnz
(),
),
1
+
other
,
device
=
src
.
device
)
value
=
other
.
add_
(
1
)
return
src
.
set_value
(
value
,
layout
=
'coo'
)
return
src
.
set_value
_
(
value
,
layout
=
'coo'
)
if
torch
.
is_tensor
(
other
):
if
src
.
has_value
():
value
=
src
.
storage
.
value
+
other
else
:
value
=
other
+
1
return
src
.
set_value
(
value
,
layout
=
'coo'
)
raise
ValueError
(
'Argument `other` needs to be of type `int`, `float` or '
@
torch
.
jit
.
script
'`torch.tensor`.'
)
def
add_nnz
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
,
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
value
=
src
.
storage
.
value
()
if
value
is
not
None
:
value
=
value
.
add
(
other
)
else
:
value
=
other
.
add
(
1
)
return
src
.
set_value
(
value
,
layout
=
layout
)
def
add_nnz_
(
src
,
other
,
layout
=
None
):
@
torch
.
jit
.
script
if
is_scalar
(
other
):
def
add_nnz_
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
,
if
src
.
has_value
():
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
value
=
src
.
storage
.
value
.
add_
(
other
)
value
=
src
.
storage
.
value
()
if
value
is
not
None
:
value
=
value
.
add_
(
other
)
else
:
else
:
value
=
torch
.
full
((
src
.
nnz
(),
),
1
+
other
,
device
=
src
.
device
)
value
=
other
.
add
(
1
)
return
src
.
set_value_
(
value
,
layout
=
'coo'
)
return
src
.
set_value_
(
value
,
layout
=
layout
)
if
torch
.
is_tensor
(
other
):
if
src
.
has_value
():
value
=
src
.
storage
.
value
.
add_
(
other
)
else
:
value
=
other
+
1
# No inplace operation possible.
return
src
.
set_value_
(
value
,
layout
=
'coo'
)
raise
ValueError
(
'Argument `other` needs to be of type `int`, `float` or '
SparseTensor
.
add
=
lambda
self
,
other
:
add
(
self
,
other
)
'`torch.tensor`.'
)
SparseTensor
.
add_
=
lambda
self
,
other
:
add_
(
self
,
other
)
SparseTensor
.
add_nnz
=
lambda
self
,
other
,
layout
=
None
:
add_nnz
(
self
,
other
,
layout
)
SparseTensor
.
add_nnz_
=
lambda
self
,
other
,
layout
=
None
:
add_nnz_
(
self
,
other
,
layout
)
torch_sparse/diag.py
View file @
4a68dd60
from
typing
import
Optional
import
torch
import
torch
from
torch_sparse.storage
import
SparseStorage
from
torch_sparse.tensor
import
SparseTensor
def
remove_diag
(
src
,
k
=
0
):
@
torch
.
jit
.
script
def
remove_diag
(
src
:
SparseTensor
,
k
:
int
=
0
)
->
SparseTensor
:
row
,
col
,
value
=
src
.
coo
()
row
,
col
,
value
=
src
.
coo
()
inv_mask
=
row
!=
col
if
k
==
0
else
row
!=
(
col
-
k
)
inv_mask
=
row
!=
col
if
k
==
0
else
row
!=
(
col
-
k
)
new_row
,
new_col
=
row
[
inv_mask
],
col
[
inv_mask
]
new_row
,
new_col
=
row
[
inv_mask
],
col
[
inv_mask
]
if
src
.
has_value
()
:
if
value
is
not
None
:
value
=
value
[
inv_mask
]
value
=
value
[
inv_mask
]
if
src
.
storage
.
has_rowcount
()
or
src
.
storage
.
has_colcount
():
rowcount
=
src
.
storage
.
_rowcount
colcount
=
src
.
storage
.
_colcount
if
rowcount
is
not
None
or
colcount
is
not
None
:
mask
=
~
inv_mask
mask
=
~
inv_mask
if
rowcount
is
not
None
:
rowcount
=
None
rowcount
=
rowcount
.
clone
()
if
src
.
storage
.
has_rowcount
():
rowcount
=
src
.
storage
.
rowcount
.
clone
()
rowcount
[
row
[
mask
]]
-=
1
rowcount
[
row
[
mask
]]
-=
1
if
colcount
is
not
None
:
colcount
=
None
colcount
=
colcount
.
clone
()
if
src
.
storage
.
has_colcount
():
colcount
=
src
.
storage
.
colcount
.
clone
()
colcount
[
col
[
mask
]]
-=
1
colcount
[
col
[
mask
]]
-=
1
storage
=
src
.
storage
.
__class__
(
row
=
new_row
,
col
=
new_col
,
value
=
value
,
storage
=
SparseStorage
(
row
=
new_row
,
rowptr
=
None
,
col
=
new_col
,
value
=
value
,
sparse_size
=
src
.
sparse_size
(),
sparse_sizes
=
src
.
sparse_sizes
(),
rowcount
=
rowcount
,
rowcount
=
rowcount
,
colcount
=
colcount
,
colptr
=
None
,
colcount
=
colcount
,
csr2csc
=
None
,
is_sorted
=
True
)
csc2csr
=
None
,
is_sorted
=
True
)
return
src
.
__class__
.
from_storage
(
storage
)
return
src
.
from_storage
(
storage
)
def
set_diag
(
src
,
values
=
None
,
k
=
0
):
if
values
is
not
None
and
not
src
.
has_value
():
raise
ValueError
(
'Sparse matrix has no values'
)
src
=
src
.
remove_diag
(
k
=
0
)
@
torch
.
jit
.
script
def
set_diag
(
src
:
SparseTensor
,
values
:
Optional
[
torch
.
Tensor
]
=
None
,
k
:
int
=
0
)
->
SparseTensor
:
src
=
remove_diag
(
src
,
k
=
0
)
row
,
col
,
value
=
src
.
coo
()
row
,
col
,
value
=
src
.
coo
()
if
row
.
is_cuda
:
if
row
.
is_cuda
:
...
@@ -47,7 +48,7 @@ def set_diag(src, values=None, k=0):
...
@@ -47,7 +48,7 @@ def set_diag(src, values=None, k=0):
inv_mask
=
~
mask
inv_mask
=
~
mask
start
,
num_diag
=
-
k
if
k
<
0
else
0
,
mask
.
numel
()
-
row
.
numel
()
start
,
num_diag
=
-
k
if
k
<
0
else
0
,
mask
.
numel
()
-
row
.
numel
()
diag
=
torch
.
arange
(
start
,
start
+
num_diag
,
device
=
src
.
device
)
diag
=
torch
.
arange
(
start
,
start
+
num_diag
,
device
=
row
.
device
)
new_row
=
row
.
new_empty
(
mask
.
size
(
0
))
new_row
=
row
.
new_empty
(
mask
.
size
(
0
))
new_row
[
mask
]
=
row
new_row
[
mask
]
=
row
...
@@ -57,25 +58,33 @@ def set_diag(src, values=None, k=0):
...
@@ -57,25 +58,33 @@ def set_diag(src, values=None, k=0):
new_col
[
mask
]
=
row
new_col
[
mask
]
=
row
new_col
[
inv_mask
]
=
diag
.
add_
(
k
)
new_col
[
inv_mask
]
=
diag
.
add_
(
k
)
new_value
=
None
new_value
:
Optional
[
torch
.
Tensor
]
=
None
if
src
.
has_value
()
:
if
value
is
not
None
:
new_value
=
value
.
new_empty
((
mask
.
size
(
0
),
)
+
value
.
size
()[
1
:])
new_value
=
value
.
new_empty
((
mask
.
size
(
0
),
)
+
value
.
size
()[
1
:])
new_value
[
mask
]
=
value
new_value
[
mask
]
=
value
new_value
[
inv_mask
]
=
values
if
values
is
not
None
else
1
if
values
is
not
None
:
new_value
[
inv_mask
]
=
values
else
:
new_value
[
inv_mask
]
=
torch
.
ones
((
num_diag
,
),
dtype
=
value
.
dtype
,
device
=
value
.
device
)
rowcount
=
None
rowcount
=
src
.
storage
.
_rowcount
if
src
.
storage
.
has_rowcount
()
:
if
rowcount
is
not
None
:
rowcount
=
src
.
storage
.
rowcount
.
clone
()
rowcount
=
rowcount
.
clone
()
rowcount
[
start
:
start
+
num_diag
]
+=
1
rowcount
[
start
:
start
+
num_diag
]
+=
1
colcount
=
None
colcount
=
src
.
storage
.
_colcount
if
src
.
storage
.
has_colcount
()
:
if
colcount
is
not
None
:
colcount
=
src
.
storage
.
colcount
.
clone
()
colcount
=
colcount
.
clone
()
colcount
[
start
+
k
:
start
+
num_diag
+
k
]
+=
1
colcount
[
start
+
k
:
start
+
num_diag
+
k
]
+=
1
storage
=
src
.
storage
.
__class__
(
row
=
new_row
,
col
=
new_col
,
value
=
new_value
,
storage
=
SparseStorage
(
row
=
new_row
,
rowptr
=
None
,
col
=
new_col
,
sparse_size
=
src
.
sparse_size
(),
value
=
new_value
,
sparse_sizes
=
src
.
sparse_sizes
(),
rowcount
=
rowcount
,
colcount
=
colcount
,
rowcount
=
rowcount
,
colptr
=
None
,
colcount
=
colcount
,
is_sorted
=
True
)
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
return
src
.
from_storage
(
storage
)
return
src
.
__class__
.
from_storage
(
storage
)
SparseTensor
.
remove_diag
=
lambda
self
,
k
=
0
:
remove_diag
(
self
,
k
)
SparseTensor
.
set_diag
=
lambda
self
,
values
=
None
,
k
=
0
:
set_diag
(
self
,
values
,
k
)
torch_sparse/index_select.py
View file @
4a68dd60
from
typing
import
Optional
import
torch
import
torch
from
torch_scatter
import
gather_csr
from
torch_scatter
import
gather_csr
from
torch_sparse.storage
import
SparseStorage
,
get_layout
from
torch_sparse.
storage
import
get_layout
from
torch_sparse.
tensor
import
SparseTensor
def
index_select
(
src
,
dim
,
idx
):
@
torch
.
jit
.
script
def
index_select
(
src
:
SparseTensor
,
dim
:
int
,
idx
:
torch
.
Tensor
)
->
SparseTensor
:
dim
=
src
.
dim
()
+
dim
if
dim
<
0
else
dim
dim
=
src
.
dim
()
+
dim
if
dim
<
0
else
dim
assert
idx
.
dim
()
==
1
assert
idx
.
dim
()
==
1
if
dim
==
0
:
if
dim
==
0
:
old_rowptr
,
col
,
value
=
src
.
csr
()
old_rowptr
,
col
,
value
=
src
.
csr
()
rowcount
=
src
.
storage
.
rowcount
rowcount
=
src
.
storage
.
rowcount
()
rowcount
=
rowcount
[
idx
]
rowcount
=
rowcount
[
idx
]
...
@@ -22,69 +25,81 @@ def index_select(src, dim, idx):
...
@@ -22,69 +25,81 @@ def index_select(src, dim, idx):
device
=
col
.
device
).
repeat_interleave
(
rowcount
)
device
=
col
.
device
).
repeat_interleave
(
rowcount
)
perm
=
torch
.
arange
(
row
.
size
(
0
),
device
=
row
.
device
)
perm
=
torch
.
arange
(
row
.
size
(
0
),
device
=
row
.
device
)
perm
+=
gather_csr
(
old_rowptr
[
idx
]
-
rowptr
[:
-
1
],
rowptr
)
# TODO
# perm += gather_csr(old_rowptr[idx] - rowptr[:-1], rowptr)
col
=
col
[
perm
]
col
=
col
[
perm
]
if
src
.
has_value
()
:
if
value
is
not
None
:
value
=
value
[
perm
]
value
=
value
[
perm
]
sparse_size
=
torch
.
Size
([
idx
.
size
(
0
),
src
.
sparse_size
(
1
)])
sparse_size
s
=
torch
.
Size
([
idx
.
size
(
0
),
src
.
sparse_size
(
1
)])
storage
=
src
.
storage
.
__class__
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
value
=
value
,
sparse_size
=
sparse_size
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
rowcount
,
rowcount
=
rowcount
,
is_sorted
=
True
)
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
return
src
.
from_storage
(
storage
)
elif
dim
==
1
:
elif
dim
==
1
:
old_colptr
,
row
,
value
=
src
.
csc
()
old_colptr
,
row
,
value
=
src
.
csc
()
colcount
=
src
.
storage
.
colcount
colcount
=
src
.
storage
.
colcount
()
colcount
=
colcount
[
idx
]
colcount
=
colcount
[
idx
]
col
=
torch
.
arange
(
idx
.
size
(
0
),
device
=
row
.
device
).
repeat_interleave
(
colcount
)
colptr
=
row
.
new_zeros
(
idx
.
size
(
0
)
+
1
)
colptr
=
row
.
new_zeros
(
idx
.
size
(
0
)
+
1
)
torch
.
cumsum
(
colcount
,
dim
=
0
,
out
=
colptr
[
1
:])
torch
.
cumsum
(
colcount
,
dim
=
0
,
out
=
colptr
[
1
:])
col
=
torch
.
arange
(
idx
.
size
(
0
),
device
=
row
.
device
).
repeat_interleave
(
colcount
)
perm
=
torch
.
arange
(
col
.
size
(
0
),
device
=
col
.
device
)
perm
=
torch
.
arange
(
col
.
size
(
0
),
device
=
col
.
device
)
perm
+=
gather_csr
(
old_colptr
[
idx
]
-
colptr
[:
-
1
],
colptr
)
# TODO
# perm += gather_csr(old_colptr[idx] - colptr[:-1], colptr)
row
=
row
[
perm
]
row
=
row
[
perm
]
csc2csr
=
(
idx
.
size
(
0
)
*
row
+
col
).
argsort
()
csc2csr
=
(
idx
.
size
(
0
)
*
row
+
col
).
argsort
()
row
,
col
=
row
[
csc2csr
],
col
[
csc2csr
]
row
,
col
=
row
[
csc2csr
],
col
[
csc2csr
]
if
src
.
has_value
()
:
if
value
is
not
None
:
value
=
value
[
perm
][
csc2csr
]
value
=
value
[
perm
][
csc2csr
]
sparse_size
=
torch
.
Size
([
src
.
sparse_size
(
0
),
idx
.
size
(
0
)])
sparse_size
s
=
torch
.
Size
([
src
.
sparse_size
(
0
),
idx
.
size
(
0
)])
storage
=
src
.
storage
.
__class__
(
row
=
row
,
col
=
col
,
value
=
value
,
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_size
=
sparse_size
,
colptr
=
colptr
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
None
,
colcount
=
colcount
,
csc2csr
=
csc2csr
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
None
,
is_sorted
=
True
)
csc2csr
=
csc2csr
,
is_sorted
=
True
)
return
src
.
from_storage
(
storage
)
else
:
else
:
storage
=
src
.
storage
.
apply_value
(
value
=
src
.
storage
.
value
()
lambda
x
:
x
.
index_select
(
dim
-
1
,
idx
))
if
value
is
not
None
:
return
src
.
set_value
(
value
.
index_select
(
dim
-
1
,
idx
),
return
src
.
from_storage
(
storage
)
layout
=
'coo'
)
else
:
raise
ValueError
def
index_select_nnz
(
src
,
idx
,
layout
=
None
):
@
torch
.
jit
.
script
def
index_select_nnz
(
src
:
SparseTensor
,
idx
:
torch
.
Tensor
,
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
assert
idx
.
dim
()
==
1
assert
idx
.
dim
()
==
1
if
get_layout
(
layout
)
==
'csc'
:
if
get_layout
(
layout
)
==
'csc'
:
idx
=
idx
[
src
.
storage
.
csc2csr
]
idx
=
src
.
storage
.
csc2csr
()[
idx
]
row
,
col
,
value
=
src
.
coo
()
row
,
col
,
value
=
src
.
coo
()
row
,
col
=
row
[
idx
],
col
[
idx
]
row
,
col
=
row
[
idx
],
col
[
idx
]
if
src
.
has_value
()
:
if
value
is
not
None
:
value
=
value
[
idx
]
value
=
value
[
idx
]
# There is no other information we can maintain...
return
SparseTensor
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
storage
=
src
.
storage
.
__class__
(
row
=
row
,
col
=
col
,
value
=
value
,
sparse_sizes
=
src
.
sparse_sizes
(),
is_sorted
=
True
)
sparse_size
=
src
.
sparse_size
(),
is_sorted
=
True
)
return
src
.
from_storage
(
storage
)
SparseTensor
.
index_select
=
lambda
self
,
dim
,
idx
:
index_select
(
self
,
dim
,
idx
)
tmp
=
lambda
self
,
idx
,
layout
=
None
:
index_select_nnz
(
# noqa
self
,
idx
,
layout
)
SparseTensor
.
index_select_nnz
=
tmp
torch_sparse/masked_select.py
View file @
4a68dd60
import
torch
from
typing
import
Optional
from
torch_sparse.storage
import
get_layout
import
torch
from
torch_sparse.storage
import
SparseStorage
,
get_layout
from
torch_sparse.tensor
import
SparseTensor
def
masked_select
(
src
,
dim
,
mask
):
@
torch
.
jit
.
script
def
masked_select
(
src
:
SparseTensor
,
dim
:
int
,
mask
:
torch
.
Tensor
)
->
SparseTensor
:
dim
=
src
.
dim
()
+
dim
if
dim
<
0
else
dim
dim
=
src
.
dim
()
+
dim
if
dim
<
0
else
dim
assert
mask
.
dim
()
==
1
assert
mask
.
dim
()
==
1
...
@@ -11,29 +15,33 @@ def masked_select(src, dim, mask):
...
@@ -11,29 +15,33 @@ def masked_select(src, dim, mask):
if
dim
==
0
:
if
dim
==
0
:
row
,
col
,
value
=
src
.
coo
()
row
,
col
,
value
=
src
.
coo
()
rowcount
=
src
.
storage
.
rowcount
rowcount
=
src
.
storage
.
rowcount
()
rowcount
=
rowcount
[
mask
]
rowcount
=
rowcount
[
mask
]
mask
=
mask
[
row
]
mask
=
mask
[
row
]
row
=
torch
.
arange
(
rowcount
.
size
(
0
),
row
=
torch
.
arange
(
rowcount
.
size
(
0
),
device
=
row
.
device
).
repeat_interleave
(
rowcount
)
device
=
row
.
device
).
repeat_interleave
(
rowcount
)
col
=
col
[
mask
]
col
=
col
[
mask
]
if
src
.
has_value
()
:
if
value
is
not
None
:
value
=
value
[
mask
]
value
=
value
[
mask
]
sparse_size
=
torch
.
Size
([
rowcount
.
size
(
0
),
src
.
sparse_size
(
1
)])
sparse_size
s
=
torch
.
Size
([
rowcount
.
size
(
0
),
src
.
sparse_size
(
1
)])
storage
=
src
.
storage
.
__class__
(
row
=
row
,
col
=
col
,
value
=
value
,
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_size
=
sparse_size
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
rowcount
,
rowcount
=
rowcount
,
is_sorted
=
True
)
colcount
=
None
,
colptr
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
return
src
.
from_storage
(
storage
)
elif
dim
==
1
:
elif
dim
==
1
:
row
,
col
,
value
=
src
.
coo
()
row
,
col
,
value
=
src
.
coo
()
csr2csc
=
src
.
storage
.
csr2csc
csr2csc
=
src
.
storage
.
csr2csc
()
row
,
col
=
row
[
csr2csc
],
col
[
csr2csc
]
row
=
row
[
csr2csc
]
colcount
=
src
.
storage
.
colcount
col
=
col
[
csr2csc
]
colcount
=
src
.
storage
.
colcount
()
colcount
=
colcount
[
mask
]
colcount
=
colcount
[
mask
]
...
@@ -44,39 +52,47 @@ def masked_select(src, dim, mask):
...
@@ -44,39 +52,47 @@ def masked_select(src, dim, mask):
csc2csr
=
(
colcount
.
size
(
0
)
*
row
+
col
).
argsort
()
csc2csr
=
(
colcount
.
size
(
0
)
*
row
+
col
).
argsort
()
row
,
col
=
row
[
csc2csr
],
col
[
csc2csr
]
row
,
col
=
row
[
csc2csr
],
col
[
csc2csr
]
if
src
.
has_value
()
:
if
value
is
not
None
:
value
=
value
[
csr2csc
][
mask
][
csc2csr
]
value
=
value
[
csr2csc
][
mask
][
csc2csr
]
sparse_size
=
torch
.
Size
([
src
.
sparse_size
(
0
),
colcount
.
size
(
0
)])
sparse_size
s
=
torch
.
Size
([
src
.
sparse_size
(
0
),
colcount
.
size
(
0
)])
storage
=
src
.
storage
.
__class__
(
row
=
row
,
col
=
col
,
value
=
value
,
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_size
=
sparse_size
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
None
,
colcount
=
colcount
,
csc2csr
=
csc2csr
,
colcount
=
colcount
,
colptr
=
None
,
csr2csc
=
None
,
is_sorted
=
True
)
csc2csr
=
csc2csr
,
is_sorted
=
True
)
return
src
.
from_storage
(
storage
)
else
:
else
:
idx
=
mask
.
nonzero
().
view
(
-
1
)
value
=
src
.
storage
.
value
()
storage
=
src
.
storage
.
apply_value
(
if
value
is
not
None
:
lambda
x
:
x
.
index_select
(
dim
-
1
,
idx
))
idx
=
mask
.
nonzero
().
flatten
()
return
src
.
set_value
(
value
.
index_select
(
dim
-
1
,
idx
),
return
src
.
from_storage
(
storage
)
layout
=
'coo'
)
else
:
raise
ValueError
def
masked_select_nnz
(
src
,
mask
,
layout
=
None
):
@
torch
.
jit
.
script
def
masked_select_nnz
(
src
:
SparseTensor
,
mask
:
torch
.
Tensor
,
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
assert
mask
.
dim
()
==
1
assert
mask
.
dim
()
==
1
if
get_layout
(
layout
)
==
'csc'
:
if
get_layout
(
layout
)
==
'csc'
:
mask
=
mask
[
src
.
storage
.
csc2csr
]
mask
=
mask
[
src
.
storage
.
csc2csr
()
]
row
,
col
,
value
=
src
.
coo
()
row
,
col
,
value
=
src
.
coo
()
row
,
col
=
row
[
mask
],
col
[
mask
]
row
,
col
=
row
[
mask
],
col
[
mask
]
if
src
.
has_value
()
:
if
value
is
not
None
:
value
=
value
[
mask
]
value
=
value
[
mask
]
# There is no other information we can maintain...
return
SparseTensor
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
storage
=
src
.
storage
.
__class__
(
row
=
row
,
col
=
col
,
value
=
value
,
sparse_sizes
=
src
.
sparse_sizes
(),
is_sorted
=
True
)
sparse_size
=
src
.
sparse_size
(),
is_sorted
=
True
)
return
src
.
from_storage
(
storage
)
SparseTensor
.
masked_select
=
lambda
self
,
dim
,
mask
:
masked_select
(
self
,
dim
,
mask
)
tmp
=
lambda
src
,
mask
,
layout
=
None
:
masked_select_nnz
(
# noqa
src
,
mask
,
layout
)
SparseTensor
.
masked_select_nnz
=
tmp
torch_sparse/mul.py
View file @
4a68dd60
from
typing
import
Optional
import
torch
import
torch
from
torch_scatter
import
gather_csr
from
torch_scatter
import
gather_csr
from
torch_sparse.utils
import
is_scalar
from
torch_sparse.tensor
import
SparseTensor
def
mul
(
src
,
other
):
if
is_scalar
(
other
):
return
mul_nnz
(
src
,
other
)
elif
torch
.
is_tensor
(
other
):
@
torch
.
jit
.
script
def
mul
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
SparseTensor
:
rowptr
,
col
,
value
=
src
.
csr
()
rowptr
,
col
,
value
=
src
.
csr
()
if
other
.
size
(
0
)
==
src
.
size
(
0
)
and
other
.
size
(
1
)
==
1
:
# Row-wise...
if
other
.
size
(
0
)
==
src
.
size
(
0
)
and
other
.
size
(
1
)
==
1
:
# Row-wise...
other
=
gather_csr
(
other
.
squeeze
(
1
),
rowptr
)
# TODO
if
src
.
has_value
():
# other = gather_csr(other.squeeze(1), rowptr)
value
=
other
.
mul_
(
src
.
storage
.
value
)
pass
else
:
elif
other
.
size
(
0
)
==
1
and
other
.
size
(
1
)
==
src
.
size
(
1
):
# Col-wise...
value
=
other
return
src
.
set_value
(
value
,
layout
=
'csr'
)
if
other
.
size
(
0
)
==
1
and
other
.
size
(
1
)
==
src
.
size
(
1
):
# Col-wise...
other
=
other
.
squeeze
(
0
)[
col
]
other
=
other
.
squeeze
(
0
)[
col
]
if
src
.
has_value
():
value
=
other
.
mul_
(
src
.
storage
.
value
)
else
:
else
:
value
=
other
return
src
.
set_value
(
value
,
layout
=
'coo'
)
raise
ValueError
(
f
'Size mismatch: Expected size (
{
src
.
size
(
0
)
}
, 1,'
raise
ValueError
(
f
'Size mismatch: Expected size (
{
src
.
size
(
0
)
}
, 1,'
f
' ...) or (1,
{
src
.
size
(
1
)
}
, ...), but got size '
f
' ...) or (1,
{
src
.
size
(
1
)
}
, ...), but got size '
f
'
{
other
.
size
()
}
.'
)
f
'
{
other
.
size
()
}
.'
)
elif
isinstance
(
other
,
src
.
__class__
):
if
value
is
not
None
:
raise
NotImplementedError
value
=
other
.
mul_
(
value
)
else
:
raise
ValueError
(
'Argument `other` needs to be of type `int`, `float`, '
value
=
other
'`torch.tensor` or `torch_sparse.SparseTensor`.'
)
return
src
.
set_value
(
value
,
layout
=
'coo'
)
def
mul_
(
src
,
other
):
if
is_scalar
(
other
):
return
mul_nnz_
(
src
,
other
)
elif
torch
.
is_tensor
(
other
):
@
torch
.
jit
.
script
def
mul_
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
SparseTensor
:
rowptr
,
col
,
value
=
src
.
csr
()
rowptr
,
col
,
value
=
src
.
csr
()
if
other
.
size
(
0
)
==
src
.
size
(
0
)
and
other
.
size
(
1
)
==
1
:
# Row-wise...
if
other
.
size
(
0
)
==
src
.
size
(
0
)
and
other
.
size
(
1
)
==
1
:
# Row-wise...
other
=
gather_csr
(
other
.
squeeze
(
1
),
rowptr
)
# TODO
if
src
.
has_value
():
# other = gather_csr(other.squeeze(1), rowptr)
value
=
src
.
storage
.
value
.
mul_
(
other
)
pass
else
:
elif
other
.
size
(
0
)
==
1
and
other
.
size
(
1
)
==
src
.
size
(
1
):
# Col-wise...
value
=
other
return
src
.
set_value_
(
value
,
layout
=
'csr'
)
if
other
.
size
(
0
)
==
1
and
other
.
size
(
1
)
==
src
.
size
(
1
):
# Col-wise...
other
=
other
.
squeeze
(
0
)[
col
]
other
=
other
.
squeeze
(
0
)[
col
]
if
src
.
has_value
():
value
=
src
.
storage
.
value
.
mul_
(
other
)
else
:
else
:
value
=
other
return
src
.
set_value_
(
value
,
layout
=
'coo'
)
raise
ValueError
(
f
'Size mismatch: Expected size (
{
src
.
size
(
0
)
}
, 1,'
raise
ValueError
(
f
'Size mismatch: Expected size (
{
src
.
size
(
0
)
}
, 1,'
f
' ...) or (1,
{
src
.
size
(
1
)
}
, ...), but got size '
f
' ...) or (1,
{
src
.
size
(
1
)
}
, ...), but got size '
f
'
{
other
.
size
()
}
.'
)
f
'
{
other
.
size
()
}
.'
)
el
if
isinstance
(
other
,
src
.
__class__
)
:
if
value
is
not
None
:
raise
NotImplementedError
value
=
value
.
mul_
(
other
)
else
:
raise
ValueError
(
'Argument `other` needs to be of type `int`, `float`, '
value
=
other
'`torch.tensor` or `torch_sparse.SparseTensor`.
'
)
return
src
.
set_value_
(
value
,
layout
=
'coo
'
)
def
mul_nnz
(
src
,
other
,
layout
=
None
):
@
torch
.
jit
.
script
if
torch
.
is_tensor
(
other
)
or
is_scalar
(
other
):
def
mul_nnz
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
,
if
src
.
has_value
():
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
value
=
src
.
storage
.
value
*
other
value
=
src
.
storage
.
value
()
if
value
is
not
None
:
value
=
value
.
mul
(
other
)
else
:
else
:
value
=
other
value
=
other
return
src
.
set_value
(
value
,
layout
=
'coo'
)
return
src
.
set_value
(
value
,
layout
=
layout
)
raise
ValueError
(
'Argument `other` needs to be of type `int`, `float` or '
'`torch.tensor`.'
)
@
torch
.
jit
.
script
def
mul_nnz_
(
src
,
other
,
layout
=
None
):
def
mul_nnz_
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
,
if
torch
.
is_tensor
(
other
)
or
is_scalar
(
other
):
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
if
src
.
has_value
():
value
=
src
.
storage
.
value
()
value
=
src
.
storage
.
value
.
mul_
(
other
)
if
value
is
not
None
:
value
=
value
.
mul_
(
other
)
else
:
else
:
value
=
other
value
=
other
return
src
.
set_value_
(
value
,
layout
=
'coo'
)
return
src
.
set_value_
(
value
,
layout
=
layout
)
raise
ValueError
(
'Argument `other` needs to be of type `int`, `float` or '
SparseTensor
.
mul
=
lambda
self
,
other
:
mul
(
self
,
other
)
'`torch.tensor`.'
)
SparseTensor
.
mul_
=
lambda
self
,
other
:
mul_
(
self
,
other
)
SparseTensor
.
mul_nnz
=
lambda
self
,
other
,
layout
=
None
:
mul_nnz
(
self
,
other
,
layout
)
SparseTensor
.
mul_nnz_
=
lambda
self
,
other
,
layout
=
None
:
mul_nnz_
(
self
,
other
,
layout
)
torch_sparse/narrow.py
View file @
4a68dd60
...
@@ -4,7 +4,8 @@ from torch_sparse.tensor import SparseTensor
...
@@ -4,7 +4,8 @@ from torch_sparse.tensor import SparseTensor
@
torch
.
jit
.
script
@
torch
.
jit
.
script
def
narrow
(
src
:
SparseTensor
,
dim
:
int
,
start
:
int
,
length
:
int
):
def
narrow
(
src
:
SparseTensor
,
dim
:
int
,
start
:
int
,
length
:
int
)
->
SparseTensor
:
dim
=
src
.
dim
()
+
dim
if
dim
<
0
else
dim
dim
=
src
.
dim
()
+
dim
if
dim
<
0
else
dim
start
=
src
.
size
(
dim
)
+
start
if
start
<
0
else
start
start
=
src
.
size
(
dim
)
+
start
if
start
<
0
else
start
...
...
torch_sparse/select.py
View file @
4a68dd60
...
@@ -4,7 +4,7 @@ from torch_sparse.narrow import narrow
...
@@ -4,7 +4,7 @@ from torch_sparse.narrow import narrow
@
torch
.
jit
.
script
@
torch
.
jit
.
script
def
select
(
src
:
SparseTensor
,
dim
:
int
,
idx
:
int
):
def
select
(
src
:
SparseTensor
,
dim
:
int
,
idx
:
int
)
->
SparseTensor
:
return
narrow
(
src
,
dim
,
start
=
idx
,
length
=
1
)
return
narrow
(
src
,
dim
,
start
=
idx
,
length
=
1
)
...
...
torch_sparse/tensor.py
View file @
4a68dd60
...
@@ -5,14 +5,6 @@ import torch
...
@@ -5,14 +5,6 @@ import torch
import
scipy.sparse
import
scipy.sparse
from
torch_sparse.storage
import
SparseStorage
,
get_layout
from
torch_sparse.storage
import
SparseStorage
,
get_layout
# from torch_sparse.index_select import index_select, index_select_nnz
# from torch_sparse.masked_select import masked_select, masked_select_nnz
# from torch_sparse.diag import remove_diag, set_diag
# import torch_sparse.reduce
# from torch_sparse.matmul import matmul
# from torch_sparse.add import add, add_, add_nnz, add_nnz_
# from torch_sparse.mul import mul, mul_, mul_nnz, mul_nnz_
from
torch_sparse.utils
import
is_scalar
from
torch_sparse.utils
import
is_scalar
...
@@ -403,12 +395,6 @@ class SparseTensor(object):
...
@@ -403,12 +395,6 @@ class SparseTensor(object):
# return matmul(self, other, reduce='sum')
# return matmul(self, other, reduce='sum')
# SparseTensor.narrow = narrow
# SparseTensor.select = select
# SparseTensor.index_select = index_select
# SparseTensor.index_select_nnz = index_select_nnz
# SparseTensor.masked_select = masked_select
# SparseTensor.masked_select_nnz = masked_select_nnz
# SparseTensor.reduction = torch_sparse.reduce.reduction
# SparseTensor.reduction = torch_sparse.reduce.reduction
# SparseTensor.sum = torch_sparse.reduce.sum
# SparseTensor.sum = torch_sparse.reduce.sum
# SparseTensor.mean = torch_sparse.reduce.mean
# SparseTensor.mean = torch_sparse.reduce.mean
...
@@ -417,14 +403,6 @@ class SparseTensor(object):
...
@@ -417,14 +403,6 @@ class SparseTensor(object):
# SparseTensor.remove_diag = remove_diag
# SparseTensor.remove_diag = remove_diag
# SparseTensor.set_diag = set_diag
# SparseTensor.set_diag = set_diag
# SparseTensor.matmul = matmul
# SparseTensor.matmul = matmul
# SparseTensor.add = add
# SparseTensor.add_ = add_
# SparseTensor.add_nnz = add_nnz
# SparseTensor.add_nnz_ = add_nnz_
# SparseTensor.mul = mul
# SparseTensor.mul_ = mul_
# SparseTensor.mul_nnz = mul_nnz
# SparseTensor.mul_nnz_ = mul_nnz_
# Python Bindings #############################################################
# Python Bindings #############################################################
...
...
torch_sparse/transpose.py
View file @
4a68dd60
...
@@ -6,7 +6,7 @@ from torch_sparse.tensor import SparseTensor
...
@@ -6,7 +6,7 @@ from torch_sparse.tensor import SparseTensor
@
torch
.
jit
.
script
@
torch
.
jit
.
script
def
t
(
src
:
SparseTensor
):
def
t
(
src
:
SparseTensor
)
->
SparseTensor
:
csr2csc
=
src
.
storage
.
csr2csc
()
csr2csc
=
src
.
storage
.
csr2csc
()
row
,
col
,
value
=
src
.
coo
()
row
,
col
,
value
=
src
.
coo
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment