Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
e6f5c3f0
Commit
e6f5c3f0
authored
Feb 19, 2020
by
rusty1s
Browse files
removed jit in external functions due to wrong caching behaviour
parent
3c259af5
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
53 additions
and
169 deletions
+53
-169
torch_sparse/add.py
torch_sparse/add.py
+0
-4
torch_sparse/cat.py
torch_sparse/cat.py
+0
-2
torch_sparse/diag.py
torch_sparse/diag.py
+10
-31
torch_sparse/index_select.py
torch_sparse/index_select.py
+0
-2
torch_sparse/masked_select.py
torch_sparse/masked_select.py
+0
-2
torch_sparse/matmul.py
torch_sparse/matmul.py
+3
-18
torch_sparse/metis.py
torch_sparse/metis.py
+0
-1
torch_sparse/mul.py
torch_sparse/mul.py
+0
-4
torch_sparse/narrow.py
torch_sparse/narrow.py
+0
-2
torch_sparse/permute.py
torch_sparse/permute.py
+0
-1
torch_sparse/reduce.py
torch_sparse/reduce.py
+0
-5
torch_sparse/select.py
torch_sparse/select.py
+0
-2
torch_sparse/storage.py
torch_sparse/storage.py
+40
-94
torch_sparse/transpose.py
torch_sparse/transpose.py
+0
-1
No files found.
torch_sparse/add.py
View file @
e6f5c3f0
...
@@ -5,7 +5,6 @@ from torch_scatter import gather_csr
...
@@ -5,7 +5,6 @@ from torch_scatter import gather_csr
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
@
torch
.
jit
.
script
def
add
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
SparseTensor
:
def
add
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
SparseTensor
:
rowptr
,
col
,
value
=
src
.
csr
()
rowptr
,
col
,
value
=
src
.
csr
()
if
other
.
size
(
0
)
==
src
.
size
(
0
)
and
other
.
size
(
1
)
==
1
:
# Row-wise...
if
other
.
size
(
0
)
==
src
.
size
(
0
)
and
other
.
size
(
1
)
==
1
:
# Row-wise...
...
@@ -24,7 +23,6 @@ def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
...
@@ -24,7 +23,6 @@ def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
return
src
.
set_value
(
value
,
layout
=
'coo'
)
return
src
.
set_value
(
value
,
layout
=
'coo'
)
@
torch
.
jit
.
script
def
add_
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
SparseTensor
:
def
add_
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
SparseTensor
:
rowptr
,
col
,
value
=
src
.
csr
()
rowptr
,
col
,
value
=
src
.
csr
()
if
other
.
size
(
0
)
==
src
.
size
(
0
)
and
other
.
size
(
1
)
==
1
:
# Row-wise...
if
other
.
size
(
0
)
==
src
.
size
(
0
)
and
other
.
size
(
1
)
==
1
:
# Row-wise...
...
@@ -44,7 +42,6 @@ def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
...
@@ -44,7 +42,6 @@ def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
return
src
.
set_value_
(
value
,
layout
=
'coo'
)
return
src
.
set_value_
(
value
,
layout
=
'coo'
)
@
torch
.
jit
.
script
def
add_nnz
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
,
def
add_nnz
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
,
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
value
=
src
.
storage
.
value
()
value
=
src
.
storage
.
value
()
...
@@ -55,7 +52,6 @@ def add_nnz(src: SparseTensor, other: torch.Tensor,
...
@@ -55,7 +52,6 @@ def add_nnz(src: SparseTensor, other: torch.Tensor,
return
src
.
set_value
(
value
,
layout
=
layout
)
return
src
.
set_value
(
value
,
layout
=
layout
)
@
torch
.
jit
.
script
def
add_nnz_
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
,
def
add_nnz_
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
,
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
value
=
src
.
storage
.
value
()
value
=
src
.
storage
.
value
()
...
...
torch_sparse/cat.py
View file @
e6f5c3f0
...
@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage
...
@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
@
torch
.
jit
.
script
def
cat
(
tensors
:
List
[
SparseTensor
],
dim
:
int
)
->
SparseTensor
:
def
cat
(
tensors
:
List
[
SparseTensor
],
dim
:
int
)
->
SparseTensor
:
assert
len
(
tensors
)
>
0
assert
len
(
tensors
)
>
0
if
dim
<
0
:
if
dim
<
0
:
...
@@ -142,7 +141,6 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
...
@@ -142,7 +141,6 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
'[{-tensors[0].dim()}, {tensors[0].dim() - 1}], but got {dim}.'
)
'[{-tensors[0].dim()}, {tensors[0].dim() - 1}], but got {dim}.'
)
@
torch
.
jit
.
script
def
cat_diag
(
tensors
:
List
[
SparseTensor
])
->
SparseTensor
:
def
cat_diag
(
tensors
:
List
[
SparseTensor
])
->
SparseTensor
:
assert
len
(
tensors
)
>
0
assert
len
(
tensors
)
>
0
...
...
torch_sparse/diag.py
View file @
e6f5c3f0
...
@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage
...
@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
@
torch
.
jit
.
script
def
remove_diag
(
src
:
SparseTensor
,
k
:
int
=
0
)
->
SparseTensor
:
def
remove_diag
(
src
:
SparseTensor
,
k
:
int
=
0
)
->
SparseTensor
:
row
,
col
,
value
=
src
.
coo
()
row
,
col
,
value
=
src
.
coo
()
inv_mask
=
row
!=
col
if
k
==
0
else
row
!=
(
col
-
k
)
inv_mask
=
row
!=
col
if
k
==
0
else
row
!=
(
col
-
k
)
...
@@ -25,24 +24,14 @@ def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
...
@@ -25,24 +24,14 @@ def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
colcount
=
colcount
.
clone
()
colcount
=
colcount
.
clone
()
colcount
[
col
[
mask
]]
-=
1
colcount
[
col
[
mask
]]
-=
1
storage
=
SparseStorage
(
storage
=
SparseStorage
(
row
=
new_row
,
rowptr
=
None
,
col
=
new_col
,
value
=
value
,
row
=
new_row
,
sparse_sizes
=
src
.
sparse_sizes
(),
rowcount
=
rowcount
,
rowptr
=
None
,
colptr
=
None
,
colcount
=
colcount
,
csr2csc
=
None
,
col
=
new_col
,
csc2csr
=
None
,
is_sorted
=
True
)
value
=
value
,
sparse_sizes
=
src
.
sparse_sizes
(),
rowcount
=
rowcount
,
colptr
=
None
,
colcount
=
colcount
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
return
src
.
from_storage
(
storage
)
return
src
.
from_storage
(
storage
)
@
torch
.
jit
.
script
def
set_diag
(
src
:
SparseTensor
,
values
:
Optional
[
torch
.
Tensor
]
=
None
,
def
set_diag
(
src
:
SparseTensor
,
values
:
Optional
[
torch
.
Tensor
]
=
None
,
k
:
int
=
0
)
->
SparseTensor
:
k
:
int
=
0
)
->
SparseTensor
:
src
=
remove_diag
(
src
,
k
=
k
)
src
=
remove_diag
(
src
,
k
=
k
)
row
,
col
,
value
=
src
.
coo
()
row
,
col
,
value
=
src
.
coo
()
...
@@ -69,8 +58,7 @@ def set_diag(src: SparseTensor,
...
@@ -69,8 +58,7 @@ def set_diag(src: SparseTensor,
if
values
is
not
None
:
if
values
is
not
None
:
new_value
[
inv_mask
]
=
values
new_value
[
inv_mask
]
=
values
else
:
else
:
new_value
[
inv_mask
]
=
torch
.
ones
((
num_diag
,
),
new_value
[
inv_mask
]
=
torch
.
ones
((
num_diag
,
),
dtype
=
value
.
dtype
,
dtype
=
value
.
dtype
,
device
=
value
.
device
)
device
=
value
.
device
)
rowcount
=
src
.
storage
.
_rowcount
rowcount
=
src
.
storage
.
_rowcount
...
@@ -83,22 +71,13 @@ def set_diag(src: SparseTensor,
...
@@ -83,22 +71,13 @@ def set_diag(src: SparseTensor,
colcount
=
colcount
.
clone
()
colcount
=
colcount
.
clone
()
colcount
[
start
+
k
:
start
+
num_diag
+
k
]
+=
1
colcount
[
start
+
k
:
start
+
num_diag
+
k
]
+=
1
storage
=
SparseStorage
(
storage
=
SparseStorage
(
row
=
new_row
,
rowptr
=
None
,
col
=
new_col
,
row
=
new_row
,
value
=
new_value
,
sparse_sizes
=
src
.
sparse_sizes
(),
rowptr
=
None
,
rowcount
=
rowcount
,
colptr
=
None
,
colcount
=
colcount
,
col
=
new_col
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
value
=
new_value
,
sparse_sizes
=
src
.
sparse_sizes
(),
rowcount
=
rowcount
,
colptr
=
None
,
colcount
=
colcount
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
return
src
.
from_storage
(
storage
)
return
src
.
from_storage
(
storage
)
@
torch
.
jit
.
script
def
fill_diag
(
src
:
SparseTensor
,
fill_value
:
int
,
k
:
int
=
0
)
->
SparseTensor
:
def
fill_diag
(
src
:
SparseTensor
,
fill_value
:
int
,
k
:
int
=
0
)
->
SparseTensor
:
num_diag
=
min
(
src
.
sparse_size
(
0
),
src
.
sparse_size
(
1
)
-
k
)
num_diag
=
min
(
src
.
sparse_size
(
0
),
src
.
sparse_size
(
1
)
-
k
)
if
k
<
0
:
if
k
<
0
:
...
...
torch_sparse/index_select.py
View file @
e6f5c3f0
...
@@ -6,7 +6,6 @@ from torch_sparse.storage import SparseStorage, get_layout
...
@@ -6,7 +6,6 @@ from torch_sparse.storage import SparseStorage, get_layout
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
@
torch
.
jit
.
script
def
index_select
(
src
:
SparseTensor
,
dim
:
int
,
def
index_select
(
src
:
SparseTensor
,
dim
:
int
,
idx
:
torch
.
Tensor
)
->
SparseTensor
:
idx
:
torch
.
Tensor
)
->
SparseTensor
:
dim
=
src
.
dim
()
+
dim
if
dim
<
0
else
dim
dim
=
src
.
dim
()
+
dim
if
dim
<
0
else
dim
...
@@ -79,7 +78,6 @@ def index_select(src: SparseTensor, dim: int,
...
@@ -79,7 +78,6 @@ def index_select(src: SparseTensor, dim: int,
raise
ValueError
raise
ValueError
@
torch
.
jit
.
script
def
index_select_nnz
(
src
:
SparseTensor
,
idx
:
torch
.
Tensor
,
def
index_select_nnz
(
src
:
SparseTensor
,
idx
:
torch
.
Tensor
,
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
assert
idx
.
dim
()
==
1
assert
idx
.
dim
()
==
1
...
...
torch_sparse/masked_select.py
View file @
e6f5c3f0
...
@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage, get_layout
...
@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage, get_layout
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
@
torch
.
jit
.
script
def
masked_select
(
src
:
SparseTensor
,
dim
:
int
,
def
masked_select
(
src
:
SparseTensor
,
dim
:
int
,
mask
:
torch
.
Tensor
)
->
SparseTensor
:
mask
:
torch
.
Tensor
)
->
SparseTensor
:
dim
=
src
.
dim
()
+
dim
if
dim
<
0
else
dim
dim
=
src
.
dim
()
+
dim
if
dim
<
0
else
dim
...
@@ -73,7 +72,6 @@ def masked_select(src: SparseTensor, dim: int,
...
@@ -73,7 +72,6 @@ def masked_select(src: SparseTensor, dim: int,
raise
ValueError
raise
ValueError
@
torch
.
jit
.
script
def
masked_select_nnz
(
src
:
SparseTensor
,
mask
:
torch
.
Tensor
,
def
masked_select_nnz
(
src
:
SparseTensor
,
mask
:
torch
.
Tensor
,
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
assert
mask
.
dim
()
==
1
assert
mask
.
dim
()
==
1
...
...
torch_sparse/matmul.py
View file @
e6f5c3f0
...
@@ -4,7 +4,6 @@ import torch
...
@@ -4,7 +4,6 @@ import torch
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
@
torch
.
jit
.
script
def
spmm_sum
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
spmm_sum
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
torch
.
Tensor
:
rowptr
,
col
,
value
=
src
.
csr
()
rowptr
,
col
,
value
=
src
.
csr
()
...
@@ -24,12 +23,10 @@ def spmm_sum(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
...
@@ -24,12 +23,10 @@ def spmm_sum(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
csr2csc
,
other
)
csr2csc
,
other
)
@
torch
.
jit
.
script
def
spmm_add
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
spmm_add
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
spmm_sum
(
src
,
other
)
return
spmm_sum
(
src
,
other
)
@
torch
.
jit
.
script
def
spmm_mean
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
spmm_mean
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
torch
.
Tensor
:
rowptr
,
col
,
value
=
src
.
csr
()
rowptr
,
col
,
value
=
src
.
csr
()
...
@@ -51,21 +48,18 @@ def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
...
@@ -51,21 +48,18 @@ def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
colptr
,
csr2csc
,
other
)
colptr
,
csr2csc
,
other
)
@
torch
.
jit
.
script
def
spmm_min
(
src
:
SparseTensor
,
def
spmm_min
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
other
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
rowptr
,
col
,
value
=
src
.
csr
()
rowptr
,
col
,
value
=
src
.
csr
()
return
torch
.
ops
.
torch_sparse
.
spmm_min
(
rowptr
,
col
,
value
,
other
)
return
torch
.
ops
.
torch_sparse
.
spmm_min
(
rowptr
,
col
,
value
,
other
)
@
torch
.
jit
.
script
def
spmm_max
(
src
:
SparseTensor
,
def
spmm_max
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
other
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
rowptr
,
col
,
value
=
src
.
csr
()
rowptr
,
col
,
value
=
src
.
csr
()
return
torch
.
ops
.
torch_sparse
.
spmm_max
(
rowptr
,
col
,
value
,
other
)
return
torch
.
ops
.
torch_sparse
.
spmm_max
(
rowptr
,
col
,
value
,
other
)
@
torch
.
jit
.
script
def
spmm
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
,
def
spmm
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
,
reduce
:
str
=
"sum"
)
->
torch
.
Tensor
:
reduce
:
str
=
"sum"
)
->
torch
.
Tensor
:
if
reduce
==
'sum'
or
reduce
==
'add'
:
if
reduce
==
'sum'
or
reduce
==
'add'
:
...
@@ -80,7 +74,6 @@ def spmm(src: SparseTensor, other: torch.Tensor,
...
@@ -80,7 +74,6 @@ def spmm(src: SparseTensor, other: torch.Tensor,
raise
ValueError
raise
ValueError
@
torch
.
jit
.
script
def
spspmm_sum
(
src
:
SparseTensor
,
other
:
SparseTensor
)
->
SparseTensor
:
def
spspmm_sum
(
src
:
SparseTensor
,
other
:
SparseTensor
)
->
SparseTensor
:
assert
src
.
sparse_size
(
1
)
==
other
.
sparse_size
(
0
)
assert
src
.
sparse_size
(
1
)
==
other
.
sparse_size
(
0
)
rowptrA
,
colA
,
valueA
=
src
.
csr
()
rowptrA
,
colA
,
valueA
=
src
.
csr
()
...
@@ -88,21 +81,14 @@ def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
...
@@ -88,21 +81,14 @@ def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
M
,
K
=
src
.
sparse_size
(
0
),
other
.
sparse_size
(
1
)
M
,
K
=
src
.
sparse_size
(
0
),
other
.
sparse_size
(
1
)
rowptrC
,
colC
,
valueC
=
torch
.
ops
.
torch_sparse
.
spspmm_sum
(
rowptrC
,
colC
,
valueC
=
torch
.
ops
.
torch_sparse
.
spspmm_sum
(
rowptrA
,
colA
,
valueA
,
rowptrB
,
colB
,
valueB
,
K
)
rowptrA
,
colA
,
valueA
,
rowptrB
,
colB
,
valueB
,
K
)
return
SparseTensor
(
return
SparseTensor
(
row
=
None
,
rowptr
=
rowptrC
,
col
=
colC
,
value
=
valueC
,
row
=
None
,
sparse_sizes
=
torch
.
Size
([
M
,
K
]),
is_sorted
=
True
)
rowptr
=
rowptrC
,
col
=
colC
,
value
=
valueC
,
sparse_sizes
=
torch
.
Size
([
M
,
K
]),
is_sorted
=
True
)
@
torch
.
jit
.
script
def
spspmm_add
(
src
:
SparseTensor
,
other
:
SparseTensor
)
->
SparseTensor
:
def
spspmm_add
(
src
:
SparseTensor
,
other
:
SparseTensor
)
->
SparseTensor
:
return
spspmm_sum
(
src
,
other
)
return
spspmm_sum
(
src
,
other
)
@
torch
.
jit
.
script
def
spspmm
(
src
:
SparseTensor
,
other
:
SparseTensor
,
def
spspmm
(
src
:
SparseTensor
,
other
:
SparseTensor
,
reduce
:
str
=
"sum"
)
->
SparseTensor
:
reduce
:
str
=
"sum"
)
->
SparseTensor
:
if
reduce
==
'sum'
or
reduce
==
'add'
:
if
reduce
==
'sum'
or
reduce
==
'add'
:
...
@@ -113,8 +99,7 @@ def spspmm(src: SparseTensor, other: SparseTensor,
...
@@ -113,8 +99,7 @@ def spspmm(src: SparseTensor, other: SparseTensor,
raise
ValueError
raise
ValueError
def
matmul
(
src
:
SparseTensor
,
def
matmul
(
src
:
SparseTensor
,
other
:
Union
[
torch
.
Tensor
,
SparseTensor
],
other
:
Union
[
torch
.
Tensor
,
SparseTensor
],
reduce
:
str
=
"sum"
):
reduce
:
str
=
"sum"
):
if
torch
.
is_tensor
(
other
):
if
torch
.
is_tensor
(
other
):
return
spmm
(
src
,
other
,
reduce
)
return
spmm
(
src
,
other
,
reduce
)
...
...
torch_sparse/metis.py
View file @
e6f5c3f0
...
@@ -5,7 +5,6 @@ from torch_sparse.tensor import SparseTensor
...
@@ -5,7 +5,6 @@ from torch_sparse.tensor import SparseTensor
from
torch_sparse.permute
import
permute
from
torch_sparse.permute
import
permute
@
torch
.
jit
.
script
def
partition_kway
(
def
partition_kway
(
src
:
SparseTensor
,
src
:
SparseTensor
,
num_parts
:
int
)
->
Tuple
[
SparseTensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
num_parts
:
int
)
->
Tuple
[
SparseTensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
...
...
torch_sparse/mul.py
View file @
e6f5c3f0
...
@@ -5,7 +5,6 @@ from torch_scatter import gather_csr
...
@@ -5,7 +5,6 @@ from torch_scatter import gather_csr
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
@
torch
.
jit
.
script
def
mul
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
SparseTensor
:
def
mul
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
SparseTensor
:
rowptr
,
col
,
value
=
src
.
csr
()
rowptr
,
col
,
value
=
src
.
csr
()
if
other
.
size
(
0
)
==
src
.
size
(
0
)
and
other
.
size
(
1
)
==
1
:
# Row-wise...
if
other
.
size
(
0
)
==
src
.
size
(
0
)
and
other
.
size
(
1
)
==
1
:
# Row-wise...
...
@@ -25,7 +24,6 @@ def mul(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
...
@@ -25,7 +24,6 @@ def mul(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
return
src
.
set_value
(
value
,
layout
=
'coo'
)
return
src
.
set_value
(
value
,
layout
=
'coo'
)
@
torch
.
jit
.
script
def
mul_
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
SparseTensor
:
def
mul_
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
SparseTensor
:
rowptr
,
col
,
value
=
src
.
csr
()
rowptr
,
col
,
value
=
src
.
csr
()
if
other
.
size
(
0
)
==
src
.
size
(
0
)
and
other
.
size
(
1
)
==
1
:
# Row-wise...
if
other
.
size
(
0
)
==
src
.
size
(
0
)
and
other
.
size
(
1
)
==
1
:
# Row-wise...
...
@@ -45,7 +43,6 @@ def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
...
@@ -45,7 +43,6 @@ def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
return
src
.
set_value_
(
value
,
layout
=
'coo'
)
return
src
.
set_value_
(
value
,
layout
=
'coo'
)
@
torch
.
jit
.
script
def
mul_nnz
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
,
def
mul_nnz
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
,
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
value
=
src
.
storage
.
value
()
value
=
src
.
storage
.
value
()
...
@@ -56,7 +53,6 @@ def mul_nnz(src: SparseTensor, other: torch.Tensor,
...
@@ -56,7 +53,6 @@ def mul_nnz(src: SparseTensor, other: torch.Tensor,
return
src
.
set_value
(
value
,
layout
=
layout
)
return
src
.
set_value
(
value
,
layout
=
layout
)
@
torch
.
jit
.
script
def
mul_nnz_
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
,
def
mul_nnz_
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
,
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
value
=
src
.
storage
.
value
()
value
=
src
.
storage
.
value
()
...
...
torch_sparse/narrow.py
View file @
e6f5c3f0
...
@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage
...
@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
@
torch
.
jit
.
script
def
narrow
(
src
:
SparseTensor
,
dim
:
int
,
start
:
int
,
def
narrow
(
src
:
SparseTensor
,
dim
:
int
,
start
:
int
,
length
:
int
)
->
SparseTensor
:
length
:
int
)
->
SparseTensor
:
if
dim
<
0
:
if
dim
<
0
:
...
@@ -80,7 +79,6 @@ def narrow(src: SparseTensor, dim: int, start: int,
...
@@ -80,7 +79,6 @@ def narrow(src: SparseTensor, dim: int, start: int,
raise
ValueError
raise
ValueError
@
torch
.
jit
.
script
def
__narrow_diag__
(
src
:
SparseTensor
,
start
:
Tuple
[
int
,
int
],
def
__narrow_diag__
(
src
:
SparseTensor
,
start
:
Tuple
[
int
,
int
],
length
:
Tuple
[
int
,
int
])
->
SparseTensor
:
length
:
Tuple
[
int
,
int
])
->
SparseTensor
:
# This function builds the inverse operation of `cat_diag` and should hence
# This function builds the inverse operation of `cat_diag` and should hence
...
...
torch_sparse/permute.py
View file @
e6f5c3f0
...
@@ -3,7 +3,6 @@ from torch_sparse.storage import SparseStorage
...
@@ -3,7 +3,6 @@ from torch_sparse.storage import SparseStorage
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
@
torch
.
jit
.
script
def
permute
(
src
:
SparseTensor
,
perm
:
torch
.
Tensor
)
->
SparseTensor
:
def
permute
(
src
:
SparseTensor
,
perm
:
torch
.
Tensor
)
->
SparseTensor
:
assert
src
.
is_symmetric
()
assert
src
.
is_symmetric
()
...
...
torch_sparse/reduce.py
View file @
e6f5c3f0
...
@@ -5,7 +5,6 @@ from torch_scatter import scatter, segment_csr
...
@@ -5,7 +5,6 @@ from torch_scatter import scatter, segment_csr
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
@
torch
.
jit
.
script
def
reduction
(
src
:
SparseTensor
,
dim
:
Optional
[
int
]
=
None
,
def
reduction
(
src
:
SparseTensor
,
dim
:
Optional
[
int
]
=
None
,
reduce
:
str
=
'sum'
)
->
torch
.
Tensor
:
reduce
:
str
=
'sum'
)
->
torch
.
Tensor
:
value
=
src
.
storage
.
value
()
value
=
src
.
storage
.
value
()
...
@@ -68,22 +67,18 @@ def reduction(src: SparseTensor, dim: Optional[int] = None,
...
@@ -68,22 +67,18 @@ def reduction(src: SparseTensor, dim: Optional[int] = None,
raise
ValueError
raise
ValueError
@
torch
.
jit
.
script
def
sum
(
src
:
SparseTensor
,
dim
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
def
sum
(
src
:
SparseTensor
,
dim
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
return
reduction
(
src
,
dim
,
reduce
=
'sum'
)
return
reduction
(
src
,
dim
,
reduce
=
'sum'
)
@
torch
.
jit
.
script
def
mean
(
src
:
SparseTensor
,
dim
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
def
mean
(
src
:
SparseTensor
,
dim
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
return
reduction
(
src
,
dim
,
reduce
=
'mean'
)
return
reduction
(
src
,
dim
,
reduce
=
'mean'
)
@
torch
.
jit
.
script
def
min
(
src
:
SparseTensor
,
dim
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
def
min
(
src
:
SparseTensor
,
dim
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
return
reduction
(
src
,
dim
,
reduce
=
'min'
)
return
reduction
(
src
,
dim
,
reduce
=
'min'
)
@
torch
.
jit
.
script
def
max
(
src
:
SparseTensor
,
dim
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
def
max
(
src
:
SparseTensor
,
dim
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
return
reduction
(
src
,
dim
,
reduce
=
'max'
)
return
reduction
(
src
,
dim
,
reduce
=
'max'
)
...
...
torch_sparse/select.py
View file @
e6f5c3f0
import
torch
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.narrow
import
narrow
from
torch_sparse.narrow
import
narrow
@
torch
.
jit
.
script
def
select
(
src
:
SparseTensor
,
dim
:
int
,
idx
:
int
)
->
SparseTensor
:
def
select
(
src
:
SparseTensor
,
dim
:
int
,
idx
:
int
)
->
SparseTensor
:
return
narrow
(
src
,
dim
,
start
=
idx
,
length
=
1
)
return
narrow
(
src
,
dim
,
start
=
idx
,
length
=
1
)
...
...
torch_sparse/storage.py
View file @
e6f5c3f0
...
@@ -30,8 +30,7 @@ class SparseStorage(object):
...
@@ -30,8 +30,7 @@ class SparseStorage(object):
_csr2csc
:
Optional
[
torch
.
Tensor
]
_csr2csc
:
Optional
[
torch
.
Tensor
]
_csc2csr
:
Optional
[
torch
.
Tensor
]
_csc2csr
:
Optional
[
torch
.
Tensor
]
def
__init__
(
self
,
def
__init__
(
self
,
row
:
Optional
[
torch
.
Tensor
]
=
None
,
row
:
Optional
[
torch
.
Tensor
]
=
None
,
rowptr
:
Optional
[
torch
.
Tensor
]
=
None
,
rowptr
:
Optional
[
torch
.
Tensor
]
=
None
,
col
:
Optional
[
torch
.
Tensor
]
=
None
,
col
:
Optional
[
torch
.
Tensor
]
=
None
,
value
:
Optional
[
torch
.
Tensor
]
=
None
,
value
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -192,8 +191,7 @@ class SparseStorage(object):
...
@@ -192,8 +191,7 @@ class SparseStorage(object):
def
value
(
self
)
->
Optional
[
torch
.
Tensor
]:
def
value
(
self
)
->
Optional
[
torch
.
Tensor
]:
return
self
.
_value
return
self
.
_value
def
set_value_
(
self
,
def
set_value_
(
self
,
value
:
Optional
[
torch
.
Tensor
],
value
:
Optional
[
torch
.
Tensor
],
layout
:
Optional
[
str
]
=
None
):
layout
:
Optional
[
str
]
=
None
):
if
value
is
not
None
:
if
value
is
not
None
:
if
get_layout
(
layout
)
==
'csc'
:
if
get_layout
(
layout
)
==
'csc'
:
...
@@ -205,8 +203,7 @@ class SparseStorage(object):
...
@@ -205,8 +203,7 @@ class SparseStorage(object):
self
.
_value
=
value
self
.
_value
=
value
return
self
return
self
def
set_value
(
self
,
def
set_value
(
self
,
value
:
Optional
[
torch
.
Tensor
],
value
:
Optional
[
torch
.
Tensor
],
layout
:
Optional
[
str
]
=
None
):
layout
:
Optional
[
str
]
=
None
):
if
value
is
not
None
:
if
value
is
not
None
:
if
get_layout
(
layout
)
==
'csc'
:
if
get_layout
(
layout
)
==
'csc'
:
...
@@ -215,18 +212,11 @@ class SparseStorage(object):
...
@@ -215,18 +212,11 @@ class SparseStorage(object):
assert
value
.
device
==
self
.
_col
.
device
assert
value
.
device
==
self
.
_col
.
device
assert
value
.
size
(
0
)
==
self
.
_col
.
numel
()
assert
value
.
size
(
0
)
==
self
.
_col
.
numel
()
return
SparseStorage
(
return
SparseStorage
(
row
=
self
.
_row
,
rowptr
=
self
.
_rowptr
,
col
=
self
.
_col
,
row
=
self
.
_row
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowptr
=
self
.
_rowptr
,
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
col
=
self
.
_col
,
colcount
=
self
.
_colcount
,
csr2csc
=
self
.
_csr2csc
,
value
=
value
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
colcount
=
self
.
_colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
def
sparse_sizes
(
self
)
->
List
[
int
]:
def
sparse_sizes
(
self
)
->
List
[
int
]:
return
self
.
_sparse_sizes
return
self
.
_sparse_sizes
...
@@ -264,18 +254,11 @@ class SparseStorage(object):
...
@@ -264,18 +254,11 @@ class SparseStorage(object):
if
colcount
is
not
None
:
if
colcount
is
not
None
:
colcount
=
colcount
[:
-
diff_1
]
colcount
=
colcount
[:
-
diff_1
]
return
SparseStorage
(
return
SparseStorage
(
row
=
self
.
_row
,
rowptr
=
rowptr
,
col
=
self
.
_col
,
row
=
self
.
_row
,
value
=
self
.
_value
,
sparse_sizes
=
sparse_sizes
,
rowptr
=
rowptr
,
rowcount
=
rowcount
,
colptr
=
colptr
,
col
=
self
.
_col
,
colcount
=
colcount
,
csr2csc
=
self
.
_csr2csc
,
value
=
self
.
_value
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
sparse_sizes
=
sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
def
has_rowcount
(
self
)
->
bool
:
def
has_rowcount
(
self
)
->
bool
:
return
self
.
_rowcount
is
not
None
return
self
.
_rowcount
is
not
None
...
@@ -320,10 +303,8 @@ class SparseStorage(object):
...
@@ -320,10 +303,8 @@ class SparseStorage(object):
if
colptr
is
not
None
:
if
colptr
is
not
None
:
colcount
=
colptr
[
1
:]
-
colptr
[:
-
1
]
colcount
=
colptr
[
1
:]
-
colptr
[:
-
1
]
else
:
else
:
colcount
=
scatter_add
(
colcount
=
scatter_add
(
torch
.
ones_like
(
self
.
_col
),
self
.
_col
,
torch
.
ones_like
(
self
.
_col
),
dim_size
=
self
.
_sparse_sizes
[
1
])
self
.
_col
,
dim_size
=
self
.
_sparse_sizes
[
1
])
self
.
_colcount
=
colcount
self
.
_colcount
=
colcount
return
colcount
return
colcount
...
@@ -375,18 +356,10 @@ class SparseStorage(object):
...
@@ -375,18 +356,10 @@ class SparseStorage(object):
value
=
segment_csr
(
value
,
ptr
,
reduce
=
reduce
)
value
=
segment_csr
(
value
,
ptr
,
reduce
=
reduce
)
value
=
value
[
0
]
if
isinstance
(
value
,
tuple
)
else
value
value
=
value
[
0
]
if
isinstance
(
value
,
tuple
)
else
value
return
SparseStorage
(
return
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
row
=
row
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
None
,
rowptr
=
None
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
col
=
col
,
csc2csr
=
None
,
is_sorted
=
True
)
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
None
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
def
fill_cache_
(
self
):
def
fill_cache_
(
self
):
self
.
row
()
self
.
row
()
...
@@ -421,18 +394,12 @@ class SparseStorage(object):
...
@@ -421,18 +394,12 @@ class SparseStorage(object):
return
count
return
count
def
copy
(
self
):
def
copy
(
self
):
return
SparseStorage
(
return
SparseStorage
(
row
=
self
.
_row
,
rowptr
=
self
.
_rowptr
,
col
=
self
.
_col
,
row
=
self
.
_row
,
value
=
self
.
_value
,
rowptr
=
self
.
_rowptr
,
sparse_sizes
=
self
.
_sparse_sizes
,
col
=
self
.
_col
,
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
value
=
self
.
_value
,
colcount
=
self
.
_colcount
,
csr2csc
=
self
.
_csr2csc
,
sparse_sizes
=
self
.
_sparse_sizes
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
colcount
=
self
.
_colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
def
clone
(
self
):
def
clone
(
self
):
row
=
self
.
_row
row
=
self
.
_row
...
@@ -460,18 +427,11 @@ class SparseStorage(object):
...
@@ -460,18 +427,11 @@ class SparseStorage(object):
csc2csr
=
self
.
_csc2csr
csc2csr
=
self
.
_csc2csr
if
csc2csr
is
not
None
:
if
csc2csr
is
not
None
:
csc2csr
=
csc2csr
.
clone
()
csc2csr
=
csc2csr
.
clone
()
return
SparseStorage
(
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
row
=
row
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowptr
=
rowptr
,
rowcount
=
rowcount
,
colptr
=
colptr
,
col
=
col
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
value
=
value
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
def
type_as
(
self
,
tensor
=
torch
.
Tensor
):
def
type_as
(
self
,
tensor
=
torch
.
Tensor
):
value
=
self
.
_value
value
=
self
.
_value
...
@@ -512,18 +472,11 @@ class SparseStorage(object):
...
@@ -512,18 +472,11 @@ class SparseStorage(object):
csc2csr
=
self
.
_csc2csr
csc2csr
=
self
.
_csc2csr
if
csc2csr
is
not
None
:
if
csc2csr
is
not
None
:
csc2csr
=
csc2csr
.
to
(
tensor
.
device
,
non_blocking
=
non_blocking
)
csc2csr
=
csc2csr
.
to
(
tensor
.
device
,
non_blocking
=
non_blocking
)
return
SparseStorage
(
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
row
=
row
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowptr
=
rowptr
,
rowcount
=
rowcount
,
colptr
=
colptr
,
col
=
col
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
value
=
value
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
def
pin_memory
(
self
):
def
pin_memory
(
self
):
row
=
self
.
_row
row
=
self
.
_row
...
@@ -551,18 +504,11 @@ class SparseStorage(object):
...
@@ -551,18 +504,11 @@ class SparseStorage(object):
csc2csr
=
self
.
_csc2csr
csc2csr
=
self
.
_csc2csr
if
csc2csr
is
not
None
:
if
csc2csr
is
not
None
:
csc2csr
=
csc2csr
.
pin_memory
()
csc2csr
=
csc2csr
.
pin_memory
()
return
SparseStorage
(
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
row
=
row
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowptr
=
rowptr
,
rowcount
=
rowcount
,
colptr
=
colptr
,
col
=
col
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
value
=
value
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
def
is_pinned
(
self
)
->
bool
:
def
is_pinned
(
self
)
->
bool
:
is_pinned
=
True
is_pinned
=
True
...
...
torch_sparse/transpose.py
View file @
e6f5c3f0
...
@@ -4,7 +4,6 @@ from torch_sparse.storage import SparseStorage
...
@@ -4,7 +4,6 @@ from torch_sparse.storage import SparseStorage
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
@
torch
.
jit
.
script
def
t
(
src
:
SparseTensor
)
->
SparseTensor
:
def
t
(
src
:
SparseTensor
)
->
SparseTensor
:
csr2csc
=
src
.
storage
.
csr2csc
()
csr2csc
=
src
.
storage
.
csr2csc
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment