Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
e6f5c3f0
"src/kernel/vscode:/vscode.git/clone" did not exist on "35bed2a9d04d01a0b11b98e59710f21a6e702ef7"
Commit
e6f5c3f0
authored
Feb 19, 2020
by
rusty1s
Browse files
removed jit in external functions due to wrong caching behaviour
parent
3c259af5
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
53 additions
and
169 deletions
+53
-169
torch_sparse/add.py
torch_sparse/add.py
+0
-4
torch_sparse/cat.py
torch_sparse/cat.py
+0
-2
torch_sparse/diag.py
torch_sparse/diag.py
+10
-31
torch_sparse/index_select.py
torch_sparse/index_select.py
+0
-2
torch_sparse/masked_select.py
torch_sparse/masked_select.py
+0
-2
torch_sparse/matmul.py
torch_sparse/matmul.py
+3
-18
torch_sparse/metis.py
torch_sparse/metis.py
+0
-1
torch_sparse/mul.py
torch_sparse/mul.py
+0
-4
torch_sparse/narrow.py
torch_sparse/narrow.py
+0
-2
torch_sparse/permute.py
torch_sparse/permute.py
+0
-1
torch_sparse/reduce.py
torch_sparse/reduce.py
+0
-5
torch_sparse/select.py
torch_sparse/select.py
+0
-2
torch_sparse/storage.py
torch_sparse/storage.py
+40
-94
torch_sparse/transpose.py
torch_sparse/transpose.py
+0
-1
No files found.
torch_sparse/add.py
View file @
e6f5c3f0
...
...
@@ -5,7 +5,6 @@ from torch_scatter import gather_csr
from
torch_sparse.tensor
import
SparseTensor
@
torch
.
jit
.
script
def
add
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
SparseTensor
:
rowptr
,
col
,
value
=
src
.
csr
()
if
other
.
size
(
0
)
==
src
.
size
(
0
)
and
other
.
size
(
1
)
==
1
:
# Row-wise...
...
...
@@ -24,7 +23,6 @@ def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
return
src
.
set_value
(
value
,
layout
=
'coo'
)
@
torch
.
jit
.
script
def
add_
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
SparseTensor
:
rowptr
,
col
,
value
=
src
.
csr
()
if
other
.
size
(
0
)
==
src
.
size
(
0
)
and
other
.
size
(
1
)
==
1
:
# Row-wise...
...
...
@@ -44,7 +42,6 @@ def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
return
src
.
set_value_
(
value
,
layout
=
'coo'
)
@
torch
.
jit
.
script
def
add_nnz
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
,
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
value
=
src
.
storage
.
value
()
...
...
@@ -55,7 +52,6 @@ def add_nnz(src: SparseTensor, other: torch.Tensor,
return
src
.
set_value
(
value
,
layout
=
layout
)
@
torch
.
jit
.
script
def
add_nnz_
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
,
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
value
=
src
.
storage
.
value
()
...
...
torch_sparse/cat.py
View file @
e6f5c3f0
...
...
@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage
from
torch_sparse.tensor
import
SparseTensor
@
torch
.
jit
.
script
def
cat
(
tensors
:
List
[
SparseTensor
],
dim
:
int
)
->
SparseTensor
:
assert
len
(
tensors
)
>
0
if
dim
<
0
:
...
...
@@ -142,7 +141,6 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
'[{-tensors[0].dim()}, {tensors[0].dim() - 1}], but got {dim}.'
)
@
torch
.
jit
.
script
def
cat_diag
(
tensors
:
List
[
SparseTensor
])
->
SparseTensor
:
assert
len
(
tensors
)
>
0
...
...
torch_sparse/diag.py
View file @
e6f5c3f0
...
...
@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage
from
torch_sparse.tensor
import
SparseTensor
@
torch
.
jit
.
script
def
remove_diag
(
src
:
SparseTensor
,
k
:
int
=
0
)
->
SparseTensor
:
row
,
col
,
value
=
src
.
coo
()
inv_mask
=
row
!=
col
if
k
==
0
else
row
!=
(
col
-
k
)
...
...
@@ -25,24 +24,14 @@ def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
colcount
=
colcount
.
clone
()
colcount
[
col
[
mask
]]
-=
1
storage
=
SparseStorage
(
row
=
new_row
,
rowptr
=
None
,
col
=
new_col
,
value
=
value
,
sparse_sizes
=
src
.
sparse_sizes
(),
rowcount
=
rowcount
,
colptr
=
None
,
colcount
=
colcount
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
storage
=
SparseStorage
(
row
=
new_row
,
rowptr
=
None
,
col
=
new_col
,
value
=
value
,
sparse_sizes
=
src
.
sparse_sizes
(),
rowcount
=
rowcount
,
colptr
=
None
,
colcount
=
colcount
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
return
src
.
from_storage
(
storage
)
@
torch
.
jit
.
script
def
set_diag
(
src
:
SparseTensor
,
values
:
Optional
[
torch
.
Tensor
]
=
None
,
def
set_diag
(
src
:
SparseTensor
,
values
:
Optional
[
torch
.
Tensor
]
=
None
,
k
:
int
=
0
)
->
SparseTensor
:
src
=
remove_diag
(
src
,
k
=
k
)
row
,
col
,
value
=
src
.
coo
()
...
...
@@ -69,8 +58,7 @@ def set_diag(src: SparseTensor,
if
values
is
not
None
:
new_value
[
inv_mask
]
=
values
else
:
new_value
[
inv_mask
]
=
torch
.
ones
((
num_diag
,
),
dtype
=
value
.
dtype
,
new_value
[
inv_mask
]
=
torch
.
ones
((
num_diag
,
),
dtype
=
value
.
dtype
,
device
=
value
.
device
)
rowcount
=
src
.
storage
.
_rowcount
...
...
@@ -83,22 +71,13 @@ def set_diag(src: SparseTensor,
colcount
=
colcount
.
clone
()
colcount
[
start
+
k
:
start
+
num_diag
+
k
]
+=
1
storage
=
SparseStorage
(
row
=
new_row
,
rowptr
=
None
,
col
=
new_col
,
value
=
new_value
,
sparse_sizes
=
src
.
sparse_sizes
(),
rowcount
=
rowcount
,
colptr
=
None
,
colcount
=
colcount
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
storage
=
SparseStorage
(
row
=
new_row
,
rowptr
=
None
,
col
=
new_col
,
value
=
new_value
,
sparse_sizes
=
src
.
sparse_sizes
(),
rowcount
=
rowcount
,
colptr
=
None
,
colcount
=
colcount
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
return
src
.
from_storage
(
storage
)
@
torch
.
jit
.
script
def
fill_diag
(
src
:
SparseTensor
,
fill_value
:
int
,
k
:
int
=
0
)
->
SparseTensor
:
num_diag
=
min
(
src
.
sparse_size
(
0
),
src
.
sparse_size
(
1
)
-
k
)
if
k
<
0
:
...
...
torch_sparse/index_select.py
View file @
e6f5c3f0
...
...
@@ -6,7 +6,6 @@ from torch_sparse.storage import SparseStorage, get_layout
from
torch_sparse.tensor
import
SparseTensor
@
torch
.
jit
.
script
def
index_select
(
src
:
SparseTensor
,
dim
:
int
,
idx
:
torch
.
Tensor
)
->
SparseTensor
:
dim
=
src
.
dim
()
+
dim
if
dim
<
0
else
dim
...
...
@@ -79,7 +78,6 @@ def index_select(src: SparseTensor, dim: int,
raise
ValueError
@
torch
.
jit
.
script
def
index_select_nnz
(
src
:
SparseTensor
,
idx
:
torch
.
Tensor
,
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
assert
idx
.
dim
()
==
1
...
...
torch_sparse/masked_select.py
View file @
e6f5c3f0
...
...
@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage, get_layout
from
torch_sparse.tensor
import
SparseTensor
@
torch
.
jit
.
script
def
masked_select
(
src
:
SparseTensor
,
dim
:
int
,
mask
:
torch
.
Tensor
)
->
SparseTensor
:
dim
=
src
.
dim
()
+
dim
if
dim
<
0
else
dim
...
...
@@ -73,7 +72,6 @@ def masked_select(src: SparseTensor, dim: int,
raise
ValueError
@
torch
.
jit
.
script
def
masked_select_nnz
(
src
:
SparseTensor
,
mask
:
torch
.
Tensor
,
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
assert
mask
.
dim
()
==
1
...
...
torch_sparse/matmul.py
View file @
e6f5c3f0
...
...
@@ -4,7 +4,6 @@ import torch
from
torch_sparse.tensor
import
SparseTensor
@
torch
.
jit
.
script
def
spmm_sum
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
torch
.
Tensor
:
rowptr
,
col
,
value
=
src
.
csr
()
...
...
@@ -24,12 +23,10 @@ def spmm_sum(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
csr2csc
,
other
)
@
torch
.
jit
.
script
def
spmm_add
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
spmm_sum
(
src
,
other
)
@
torch
.
jit
.
script
def
spmm_mean
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
torch
.
Tensor
:
rowptr
,
col
,
value
=
src
.
csr
()
...
...
@@ -51,21 +48,18 @@ def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
colptr
,
csr2csc
,
other
)
@
torch
.
jit
.
script
def
spmm_min
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
rowptr
,
col
,
value
=
src
.
csr
()
return
torch
.
ops
.
torch_sparse
.
spmm_min
(
rowptr
,
col
,
value
,
other
)
@
torch
.
jit
.
script
def
spmm_max
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
rowptr
,
col
,
value
=
src
.
csr
()
return
torch
.
ops
.
torch_sparse
.
spmm_max
(
rowptr
,
col
,
value
,
other
)
@
torch
.
jit
.
script
def
spmm
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
,
reduce
:
str
=
"sum"
)
->
torch
.
Tensor
:
if
reduce
==
'sum'
or
reduce
==
'add'
:
...
...
@@ -80,7 +74,6 @@ def spmm(src: SparseTensor, other: torch.Tensor,
raise
ValueError
@
torch
.
jit
.
script
def
spspmm_sum
(
src
:
SparseTensor
,
other
:
SparseTensor
)
->
SparseTensor
:
assert
src
.
sparse_size
(
1
)
==
other
.
sparse_size
(
0
)
rowptrA
,
colA
,
valueA
=
src
.
csr
()
...
...
@@ -88,21 +81,14 @@ def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
M
,
K
=
src
.
sparse_size
(
0
),
other
.
sparse_size
(
1
)
rowptrC
,
colC
,
valueC
=
torch
.
ops
.
torch_sparse
.
spspmm_sum
(
rowptrA
,
colA
,
valueA
,
rowptrB
,
colB
,
valueB
,
K
)
return
SparseTensor
(
row
=
None
,
rowptr
=
rowptrC
,
col
=
colC
,
value
=
valueC
,
sparse_sizes
=
torch
.
Size
([
M
,
K
]),
is_sorted
=
True
)
return
SparseTensor
(
row
=
None
,
rowptr
=
rowptrC
,
col
=
colC
,
value
=
valueC
,
sparse_sizes
=
torch
.
Size
([
M
,
K
]),
is_sorted
=
True
)
@
torch
.
jit
.
script
def
spspmm_add
(
src
:
SparseTensor
,
other
:
SparseTensor
)
->
SparseTensor
:
return
spspmm_sum
(
src
,
other
)
@
torch
.
jit
.
script
def
spspmm
(
src
:
SparseTensor
,
other
:
SparseTensor
,
reduce
:
str
=
"sum"
)
->
SparseTensor
:
if
reduce
==
'sum'
or
reduce
==
'add'
:
...
...
@@ -113,8 +99,7 @@ def spspmm(src: SparseTensor, other: SparseTensor,
raise
ValueError
def
matmul
(
src
:
SparseTensor
,
other
:
Union
[
torch
.
Tensor
,
SparseTensor
],
def
matmul
(
src
:
SparseTensor
,
other
:
Union
[
torch
.
Tensor
,
SparseTensor
],
reduce
:
str
=
"sum"
):
if
torch
.
is_tensor
(
other
):
return
spmm
(
src
,
other
,
reduce
)
...
...
torch_sparse/metis.py
View file @
e6f5c3f0
...
...
@@ -5,7 +5,6 @@ from torch_sparse.tensor import SparseTensor
from
torch_sparse.permute
import
permute
@
torch
.
jit
.
script
def
partition_kway
(
src
:
SparseTensor
,
num_parts
:
int
)
->
Tuple
[
SparseTensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
...
...
torch_sparse/mul.py
View file @
e6f5c3f0
...
...
@@ -5,7 +5,6 @@ from torch_scatter import gather_csr
from
torch_sparse.tensor
import
SparseTensor
@
torch
.
jit
.
script
def
mul
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
SparseTensor
:
rowptr
,
col
,
value
=
src
.
csr
()
if
other
.
size
(
0
)
==
src
.
size
(
0
)
and
other
.
size
(
1
)
==
1
:
# Row-wise...
...
...
@@ -25,7 +24,6 @@ def mul(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
return
src
.
set_value
(
value
,
layout
=
'coo'
)
@
torch
.
jit
.
script
def
mul_
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
SparseTensor
:
rowptr
,
col
,
value
=
src
.
csr
()
if
other
.
size
(
0
)
==
src
.
size
(
0
)
and
other
.
size
(
1
)
==
1
:
# Row-wise...
...
...
@@ -45,7 +43,6 @@ def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
return
src
.
set_value_
(
value
,
layout
=
'coo'
)
@
torch
.
jit
.
script
def
mul_nnz
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
,
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
value
=
src
.
storage
.
value
()
...
...
@@ -56,7 +53,6 @@ def mul_nnz(src: SparseTensor, other: torch.Tensor,
return
src
.
set_value
(
value
,
layout
=
layout
)
@
torch
.
jit
.
script
def
mul_nnz_
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
,
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
value
=
src
.
storage
.
value
()
...
...
torch_sparse/narrow.py
View file @
e6f5c3f0
...
...
@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage
from
torch_sparse.tensor
import
SparseTensor
@
torch
.
jit
.
script
def
narrow
(
src
:
SparseTensor
,
dim
:
int
,
start
:
int
,
length
:
int
)
->
SparseTensor
:
if
dim
<
0
:
...
...
@@ -80,7 +79,6 @@ def narrow(src: SparseTensor, dim: int, start: int,
raise
ValueError
@
torch
.
jit
.
script
def
__narrow_diag__
(
src
:
SparseTensor
,
start
:
Tuple
[
int
,
int
],
length
:
Tuple
[
int
,
int
])
->
SparseTensor
:
# This function builds the inverse operation of `cat_diag` and should hence
...
...
torch_sparse/permute.py
View file @
e6f5c3f0
...
...
@@ -3,7 +3,6 @@ from torch_sparse.storage import SparseStorage
from
torch_sparse.tensor
import
SparseTensor
@
torch
.
jit
.
script
def
permute
(
src
:
SparseTensor
,
perm
:
torch
.
Tensor
)
->
SparseTensor
:
assert
src
.
is_symmetric
()
...
...
torch_sparse/reduce.py
View file @
e6f5c3f0
...
...
@@ -5,7 +5,6 @@ from torch_scatter import scatter, segment_csr
from
torch_sparse.tensor
import
SparseTensor
@
torch
.
jit
.
script
def
reduction
(
src
:
SparseTensor
,
dim
:
Optional
[
int
]
=
None
,
reduce
:
str
=
'sum'
)
->
torch
.
Tensor
:
value
=
src
.
storage
.
value
()
...
...
@@ -68,22 +67,18 @@ def reduction(src: SparseTensor, dim: Optional[int] = None,
raise
ValueError
@
torch
.
jit
.
script
def
sum
(
src
:
SparseTensor
,
dim
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
return
reduction
(
src
,
dim
,
reduce
=
'sum'
)
@
torch
.
jit
.
script
def
mean
(
src
:
SparseTensor
,
dim
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
return
reduction
(
src
,
dim
,
reduce
=
'mean'
)
@
torch
.
jit
.
script
def
min
(
src
:
SparseTensor
,
dim
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
return
reduction
(
src
,
dim
,
reduce
=
'min'
)
@
torch
.
jit
.
script
def
max
(
src
:
SparseTensor
,
dim
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
return
reduction
(
src
,
dim
,
reduce
=
'max'
)
...
...
torch_sparse/select.py
View file @
e6f5c3f0
import
torch
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.narrow
import
narrow
@
torch
.
jit
.
script
def
select
(
src
:
SparseTensor
,
dim
:
int
,
idx
:
int
)
->
SparseTensor
:
return
narrow
(
src
,
dim
,
start
=
idx
,
length
=
1
)
...
...
torch_sparse/storage.py
View file @
e6f5c3f0
...
...
@@ -30,8 +30,7 @@ class SparseStorage(object):
_csr2csc
:
Optional
[
torch
.
Tensor
]
_csc2csr
:
Optional
[
torch
.
Tensor
]
def
__init__
(
self
,
row
:
Optional
[
torch
.
Tensor
]
=
None
,
def
__init__
(
self
,
row
:
Optional
[
torch
.
Tensor
]
=
None
,
rowptr
:
Optional
[
torch
.
Tensor
]
=
None
,
col
:
Optional
[
torch
.
Tensor
]
=
None
,
value
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -192,8 +191,7 @@ class SparseStorage(object):
def
value
(
self
)
->
Optional
[
torch
.
Tensor
]:
return
self
.
_value
def
set_value_
(
self
,
value
:
Optional
[
torch
.
Tensor
],
def
set_value_
(
self
,
value
:
Optional
[
torch
.
Tensor
],
layout
:
Optional
[
str
]
=
None
):
if
value
is
not
None
:
if
get_layout
(
layout
)
==
'csc'
:
...
...
@@ -205,8 +203,7 @@ class SparseStorage(object):
self
.
_value
=
value
return
self
def
set_value
(
self
,
value
:
Optional
[
torch
.
Tensor
],
def
set_value
(
self
,
value
:
Optional
[
torch
.
Tensor
],
layout
:
Optional
[
str
]
=
None
):
if
value
is
not
None
:
if
get_layout
(
layout
)
==
'csc'
:
...
...
@@ -215,18 +212,11 @@ class SparseStorage(object):
assert
value
.
device
==
self
.
_col
.
device
assert
value
.
size
(
0
)
==
self
.
_col
.
numel
()
return
SparseStorage
(
row
=
self
.
_row
,
rowptr
=
self
.
_rowptr
,
col
=
self
.
_col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
colcount
=
self
.
_colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
return
SparseStorage
(
row
=
self
.
_row
,
rowptr
=
self
.
_rowptr
,
col
=
self
.
_col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
colcount
=
self
.
_colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
def
sparse_sizes
(
self
)
->
List
[
int
]:
return
self
.
_sparse_sizes
...
...
@@ -264,18 +254,11 @@ class SparseStorage(object):
if
colcount
is
not
None
:
colcount
=
colcount
[:
-
diff_1
]
return
SparseStorage
(
row
=
self
.
_row
,
rowptr
=
rowptr
,
col
=
self
.
_col
,
value
=
self
.
_value
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
return
SparseStorage
(
row
=
self
.
_row
,
rowptr
=
rowptr
,
col
=
self
.
_col
,
value
=
self
.
_value
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
def
has_rowcount
(
self
)
->
bool
:
return
self
.
_rowcount
is
not
None
...
...
@@ -320,10 +303,8 @@ class SparseStorage(object):
if
colptr
is
not
None
:
colcount
=
colptr
[
1
:]
-
colptr
[:
-
1
]
else
:
colcount
=
scatter_add
(
torch
.
ones_like
(
self
.
_col
),
self
.
_col
,
dim_size
=
self
.
_sparse_sizes
[
1
])
colcount
=
scatter_add
(
torch
.
ones_like
(
self
.
_col
),
self
.
_col
,
dim_size
=
self
.
_sparse_sizes
[
1
])
self
.
_colcount
=
colcount
return
colcount
...
...
@@ -375,18 +356,10 @@ class SparseStorage(object):
value
=
segment_csr
(
value
,
ptr
,
reduce
=
reduce
)
value
=
value
[
0
]
if
isinstance
(
value
,
tuple
)
else
value
return
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
None
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
return
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
None
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
def
fill_cache_
(
self
):
self
.
row
()
...
...
@@ -421,18 +394,12 @@ class SparseStorage(object):
return
count
def
copy
(
self
):
return
SparseStorage
(
row
=
self
.
_row
,
rowptr
=
self
.
_rowptr
,
col
=
self
.
_col
,
value
=
self
.
_value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
colcount
=
self
.
_colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
return
SparseStorage
(
row
=
self
.
_row
,
rowptr
=
self
.
_rowptr
,
col
=
self
.
_col
,
value
=
self
.
_value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
colcount
=
self
.
_colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
def
clone
(
self
):
row
=
self
.
_row
...
...
@@ -460,18 +427,11 @@ class SparseStorage(object):
csc2csr
=
self
.
_csc2csr
if
csc2csr
is
not
None
:
csc2csr
=
csc2csr
.
clone
()
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
def
type_as
(
self
,
tensor
=
torch
.
Tensor
):
value
=
self
.
_value
...
...
@@ -512,18 +472,11 @@ class SparseStorage(object):
csc2csr
=
self
.
_csc2csr
if
csc2csr
is
not
None
:
csc2csr
=
csc2csr
.
to
(
tensor
.
device
,
non_blocking
=
non_blocking
)
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
def
pin_memory
(
self
):
row
=
self
.
_row
...
...
@@ -551,18 +504,11 @@ class SparseStorage(object):
csc2csr
=
self
.
_csc2csr
if
csc2csr
is
not
None
:
csc2csr
=
csc2csr
.
pin_memory
()
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
def
is_pinned
(
self
)
->
bool
:
is_pinned
=
True
...
...
torch_sparse/transpose.py
View file @
e6f5c3f0
...
...
@@ -4,7 +4,6 @@ from torch_sparse.storage import SparseStorage
from
torch_sparse.tensor
import
SparseTensor
@
torch
.
jit
.
script
def
t
(
src
:
SparseTensor
)
->
SparseTensor
:
csr2csc
=
src
.
storage
.
csr2csc
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment