Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
15ff09d5
Commit
15ff09d5
authored
Feb 14, 2020
by
rusty1s
Browse files
removed load library calls
parent
60a29466
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
191 additions
and
106 deletions
+191
-106
torch_sparse/__init__.py
torch_sparse/__init__.py
+23
-23
torch_sparse/cat.py
torch_sparse/cat.py
+37
-13
torch_sparse/diag.py
torch_sparse/diag.py
+28
-15
torch_sparse/intersection.py
torch_sparse/intersection.py
+0
-0
torch_sparse/matmul.py
torch_sparse/matmul.py
+9
-10
torch_sparse/storage.py
torch_sparse/storage.py
+94
-45
No files found.
torch_sparse/__init__.py
View file @
15ff09d5
# flake8: noqa
import
importlib
import
importlib
import
os.path
as
osp
import
os.path
as
osp
...
@@ -9,8 +7,9 @@ __version__ = '0.5.1'
...
@@ -9,8 +7,9 @@ __version__ = '0.5.1'
expected_torch_version
=
(
1
,
4
)
expected_torch_version
=
(
1
,
4
)
try
:
try
:
torch
.
ops
.
load_library
(
importlib
.
machinery
.
PathFinder
().
find_spec
(
for
library
in
[
'_version'
,
'_convert'
,
'_diag'
,
'_spmm'
,
'_spspmm'
]:
'_version'
,
[
osp
.
dirname
(
__file__
)]).
origin
)
torch
.
ops
.
load_library
(
importlib
.
machinery
.
PathFinder
().
find_spec
(
library
,
[
osp
.
dirname
(
__file__
)]).
origin
)
except
OSError
as
e
:
except
OSError
as
e
:
if
'undefined symbol'
in
str
(
e
):
if
'undefined symbol'
in
str
(
e
):
major
,
minor
=
[
int
(
x
)
for
x
in
torch
.
__version__
.
split
(
'.'
)[:
2
]]
major
,
minor
=
[
int
(
x
)
for
x
in
torch
.
__version__
.
split
(
'.'
)[:
2
]]
...
@@ -40,26 +39,27 @@ if torch.version.cuda is not None: # pragma: no cover
...
@@ -40,26 +39,27 @@ if torch.version.cuda is not None: # pragma: no cover
f
'
{
major
}
.
{
minor
}
. Please reinstall the torch_sparse that '
f
'
{
major
}
.
{
minor
}
. Please reinstall the torch_sparse that '
f
'matches your PyTorch install.'
)
f
'matches your PyTorch install.'
)
from
.storage
import
SparseStorage
from
.storage
import
SparseStorage
# noqa: E4402
from
.tensor
import
SparseTensor
from
.tensor
import
SparseTensor
# noqa: E4402
from
.transpose
import
t
from
.transpose
import
t
# noqa: E4402
from
.narrow
import
narrow
,
__narrow_diag__
from
.narrow
import
narrow
,
__narrow_diag__
# noqa: E4402
from
.select
import
select
from
.select
import
select
# noqa: E4402
from
.index_select
import
index_select
,
index_select_nnz
from
.index_select
import
index_select
,
index_select_nnz
# noqa: E4402
from
.masked_select
import
masked_select
,
masked_select_nnz
from
.masked_select
import
masked_select
,
masked_select_nnz
# noqa: E4402
from
.diag
import
remove_diag
,
set_diag
,
fill_diag
from
.diag
import
remove_diag
,
set_diag
,
fill_diag
# noqa: E4402
from
.add
import
add
,
add_
,
add_nnz
,
add_nnz_
from
.add
import
add
,
add_
,
add_nnz
,
add_nnz_
# noqa: E4402
from
.mul
import
mul
,
mul_
,
mul_nnz
,
mul_nnz_
from
.mul
import
mul
,
mul_
,
mul_nnz
,
mul_nnz_
# noqa: E4402
from
.reduce
import
sum
,
mean
,
min
,
max
from
.reduce
import
sum
,
mean
,
min
,
max
# noqa: E4402
from
.matmul
import
matmul
from
.matmul
import
matmul
# noqa: E4402
from
.cat
import
cat
,
cat_diag
from
.cat
import
cat
,
cat_diag
# noqa: E4402
from
.convert
import
to_torch_sparse
,
from_torch_sparse
,
to_scipy
,
from_scipy
from
.convert
import
to_torch_sparse
,
from_torch_sparse
# noqa: E4402
from
.coalesce
import
coalesce
from
.convert
import
to_scipy
,
from_scipy
# noqa: E4402
from
.transpose
import
transpose
from
.coalesce
import
coalesce
# noqa: E4402
from
.eye
import
eye
from
.transpose
import
transpose
# noqa: E4402
from
.spmm
import
spmm
from
.eye
import
eye
# noqa: E4402
from
.spspmm
import
spspmm
from
.spmm
import
spmm
# noqa: E4402
from
.spspmm
import
spspmm
# noqa: E4402
__all__
=
[
__all__
=
[
'SparseStorage'
,
'SparseStorage'
,
...
...
torch_sparse/cat.py
View file @
15ff09d5
from
typing
import
List
,
Optional
from
typing
import
List
import
torch
import
torch
from
torch_sparse.storage
import
SparseStorage
from
torch_sparse.storage
import
SparseStorage
...
@@ -63,10 +63,18 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
...
@@ -63,10 +63,18 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
if
len
(
rowcounts
)
==
len
(
tensors
):
if
len
(
rowcounts
)
==
len
(
tensors
):
rowcount
=
torch
.
cat
(
rowcounts
,
dim
=
0
)
rowcount
=
torch
.
cat
(
rowcounts
,
dim
=
0
)
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
storage
=
SparseStorage
(
sparse_sizes
=
sparse_sizes
,
rowcount
=
rowcount
,
row
=
row
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
rowptr
=
rowptr
,
csc2csr
=
None
,
is_sorted
=
True
)
col
=
col
,
value
=
value
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
return
tensors
[
0
].
from_storage
(
storage
)
return
tensors
[
0
].
from_storage
(
storage
)
elif
dim
==
1
:
elif
dim
==
1
:
...
@@ -118,10 +126,18 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
...
@@ -118,10 +126,18 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
if
len
(
colcounts
)
==
len
(
tensors
):
if
len
(
colcounts
)
==
len
(
tensors
):
colcount
=
torch
.
cat
(
colcounts
,
dim
=
0
)
colcount
=
torch
.
cat
(
colcounts
,
dim
=
0
)
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
storage
=
SparseStorage
(
sparse_sizes
=
sparse_sizes
,
rowcount
=
None
,
row
=
row
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
None
,
rowptr
=
None
,
csc2csr
=
None
,
is_sorted
=
False
)
col
=
col
,
value
=
value
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
None
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
False
)
return
tensors
[
0
].
from_storage
(
storage
)
return
tensors
[
0
].
from_storage
(
storage
)
elif
dim
>
1
and
dim
<
tensors
[
0
].
dim
():
elif
dim
>
1
and
dim
<
tensors
[
0
].
dim
():
...
@@ -235,8 +251,16 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
...
@@ -235,8 +251,16 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
if
len
(
csc2csrs
)
==
len
(
tensors
):
if
len
(
csc2csrs
)
==
len
(
tensors
):
csc2csr
=
torch
.
cat
(
csc2csrs
,
dim
=
0
)
csc2csr
=
torch
.
cat
(
csc2csrs
,
dim
=
0
)
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
storage
=
SparseStorage
(
sparse_sizes
=
sparse_sizes
,
rowcount
=
rowcount
,
row
=
row
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
rowptr
=
rowptr
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
col
=
col
,
value
=
value
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
return
tensors
[
0
].
from_storage
(
storage
)
return
tensors
[
0
].
from_storage
(
storage
)
torch_sparse/diag.py
View file @
15ff09d5
import
importlib
import
os.path
as
osp
from
typing
import
Optional
from
typing
import
Optional
import
torch
import
torch
from
torch_sparse.storage
import
SparseStorage
from
torch_sparse.storage
import
SparseStorage
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
torch
.
ops
.
load_library
(
importlib
.
machinery
.
PathFinder
().
find_spec
(
'_diag'
,
[
osp
.
dirname
(
__file__
)]).
origin
)
@
torch
.
jit
.
script
@
torch
.
jit
.
script
def
remove_diag
(
src
:
SparseTensor
,
k
:
int
=
0
)
->
SparseTensor
:
def
remove_diag
(
src
:
SparseTensor
,
k
:
int
=
0
)
->
SparseTensor
:
...
@@ -30,15 +25,24 @@ def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
...
@@ -30,15 +25,24 @@ def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
colcount
=
colcount
.
clone
()
colcount
=
colcount
.
clone
()
colcount
[
col
[
mask
]]
-=
1
colcount
[
col
[
mask
]]
-=
1
storage
=
SparseStorage
(
row
=
new_row
,
rowptr
=
None
,
col
=
new_col
,
value
=
value
,
storage
=
SparseStorage
(
sparse_sizes
=
src
.
sparse_sizes
(),
rowcount
=
rowcount
,
row
=
new_row
,
colptr
=
None
,
colcount
=
colcount
,
csr2csc
=
None
,
rowptr
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
col
=
new_col
,
value
=
value
,
sparse_sizes
=
src
.
sparse_sizes
(),
rowcount
=
rowcount
,
colptr
=
None
,
colcount
=
colcount
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
return
src
.
from_storage
(
storage
)
return
src
.
from_storage
(
storage
)
@
torch
.
jit
.
script
@
torch
.
jit
.
script
def
set_diag
(
src
:
SparseTensor
,
values
:
Optional
[
torch
.
Tensor
]
=
None
,
def
set_diag
(
src
:
SparseTensor
,
values
:
Optional
[
torch
.
Tensor
]
=
None
,
k
:
int
=
0
)
->
SparseTensor
:
k
:
int
=
0
)
->
SparseTensor
:
src
=
remove_diag
(
src
,
k
=
k
)
src
=
remove_diag
(
src
,
k
=
k
)
row
,
col
,
value
=
src
.
coo
()
row
,
col
,
value
=
src
.
coo
()
...
@@ -65,7 +69,8 @@ def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
...
@@ -65,7 +69,8 @@ def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
if
values
is
not
None
:
if
values
is
not
None
:
new_value
[
inv_mask
]
=
values
new_value
[
inv_mask
]
=
values
else
:
else
:
new_value
[
inv_mask
]
=
torch
.
ones
((
num_diag
,
),
dtype
=
value
.
dtype
,
new_value
[
inv_mask
]
=
torch
.
ones
((
num_diag
,
),
dtype
=
value
.
dtype
,
device
=
value
.
device
)
device
=
value
.
device
)
rowcount
=
src
.
storage
.
_rowcount
rowcount
=
src
.
storage
.
_rowcount
...
@@ -78,10 +83,18 @@ def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
...
@@ -78,10 +83,18 @@ def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
colcount
=
colcount
.
clone
()
colcount
=
colcount
.
clone
()
colcount
[
start
+
k
:
start
+
num_diag
+
k
]
+=
1
colcount
[
start
+
k
:
start
+
num_diag
+
k
]
+=
1
storage
=
SparseStorage
(
row
=
new_row
,
rowptr
=
None
,
col
=
new_col
,
storage
=
SparseStorage
(
value
=
new_value
,
sparse_sizes
=
src
.
sparse_sizes
(),
row
=
new_row
,
rowcount
=
rowcount
,
colptr
=
None
,
colcount
=
colcount
,
rowptr
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
col
=
new_col
,
value
=
new_value
,
sparse_sizes
=
src
.
sparse_sizes
(),
rowcount
=
rowcount
,
colptr
=
None
,
colcount
=
colcount
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
return
src
.
from_storage
(
storage
)
return
src
.
from_storage
(
storage
)
...
...
torch_sparse/intersection.py
deleted
100644 → 0
View file @
60a29466
torch_sparse/matmul.py
View file @
15ff09d5
import
importlib
import
os.path
as
osp
from
typing
import
Union
,
Tuple
from
typing
import
Union
,
Tuple
import
torch
import
torch
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
torch
.
ops
.
load_library
(
importlib
.
machinery
.
PathFinder
().
find_spec
(
'_spmm'
,
[
osp
.
dirname
(
__file__
)]).
origin
)
torch
.
ops
.
load_library
(
importlib
.
machinery
.
PathFinder
().
find_spec
(
'_spspmm'
,
[
osp
.
dirname
(
__file__
)]).
origin
)
@
torch
.
jit
.
script
@
torch
.
jit
.
script
def
spmm_sum
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
spmm_sum
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -95,8 +88,13 @@ def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
...
@@ -95,8 +88,13 @@ def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
M
,
K
=
src
.
sparse_size
(
0
),
other
.
sparse_size
(
1
)
M
,
K
=
src
.
sparse_size
(
0
),
other
.
sparse_size
(
1
)
rowptrC
,
colC
,
valueC
=
torch
.
ops
.
torch_sparse
.
spspmm_sum
(
rowptrC
,
colC
,
valueC
=
torch
.
ops
.
torch_sparse
.
spspmm_sum
(
rowptrA
,
colA
,
valueA
,
rowptrB
,
colB
,
valueB
,
K
)
rowptrA
,
colA
,
valueA
,
rowptrB
,
colB
,
valueB
,
K
)
return
SparseTensor
(
row
=
None
,
rowptr
=
rowptrC
,
col
=
colC
,
value
=
valueC
,
return
SparseTensor
(
sparse_sizes
=
torch
.
Size
([
M
,
K
]),
is_sorted
=
True
)
row
=
None
,
rowptr
=
rowptrC
,
col
=
colC
,
value
=
valueC
,
sparse_sizes
=
torch
.
Size
([
M
,
K
]),
is_sorted
=
True
)
@
torch
.
jit
.
script
@
torch
.
jit
.
script
...
@@ -115,7 +113,8 @@ def spspmm(src: SparseTensor, other: SparseTensor,
...
@@ -115,7 +113,8 @@ def spspmm(src: SparseTensor, other: SparseTensor,
raise
ValueError
raise
ValueError
def
matmul
(
src
:
SparseTensor
,
other
:
Union
[
torch
.
Tensor
,
SparseTensor
],
def
matmul
(
src
:
SparseTensor
,
other
:
Union
[
torch
.
Tensor
,
SparseTensor
],
reduce
:
str
=
"sum"
):
reduce
:
str
=
"sum"
):
if
torch
.
is_tensor
(
other
):
if
torch
.
is_tensor
(
other
):
return
spmm
(
src
,
other
,
reduce
)
return
spmm
(
src
,
other
,
reduce
)
...
...
torch_sparse/storage.py
View file @
15ff09d5
import
warnings
import
warnings
import
importlib
import
os.path
as
osp
from
typing
import
Optional
,
List
from
typing
import
Optional
,
List
import
torch
import
torch
from
torch_scatter
import
segment_csr
,
scatter_add
from
torch_scatter
import
segment_csr
,
scatter_add
from
torch_sparse.utils
import
Final
from
torch_sparse.utils
import
Final
torch
.
ops
.
load_library
(
importlib
.
machinery
.
PathFinder
().
find_spec
(
'_convert'
,
[
osp
.
dirname
(
__file__
)]).
origin
)
layouts
:
Final
[
List
[
str
]]
=
[
'coo'
,
'csr'
,
'csc'
]
layouts
:
Final
[
List
[
str
]]
=
[
'coo'
,
'csr'
,
'csc'
]
...
@@ -35,7 +30,8 @@ class SparseStorage(object):
...
@@ -35,7 +30,8 @@ class SparseStorage(object):
_csr2csc
:
Optional
[
torch
.
Tensor
]
_csr2csc
:
Optional
[
torch
.
Tensor
]
_csc2csr
:
Optional
[
torch
.
Tensor
]
_csc2csr
:
Optional
[
torch
.
Tensor
]
def
__init__
(
self
,
row
:
Optional
[
torch
.
Tensor
]
=
None
,
def
__init__
(
self
,
row
:
Optional
[
torch
.
Tensor
]
=
None
,
rowptr
:
Optional
[
torch
.
Tensor
]
=
None
,
rowptr
:
Optional
[
torch
.
Tensor
]
=
None
,
col
:
Optional
[
torch
.
Tensor
]
=
None
,
col
:
Optional
[
torch
.
Tensor
]
=
None
,
value
:
Optional
[
torch
.
Tensor
]
=
None
,
value
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -196,7 +192,8 @@ class SparseStorage(object):
...
@@ -196,7 +192,8 @@ class SparseStorage(object):
def
value
(
self
)
->
Optional
[
torch
.
Tensor
]:
def
value
(
self
)
->
Optional
[
torch
.
Tensor
]:
return
self
.
_value
return
self
.
_value
def
set_value_
(
self
,
value
:
Optional
[
torch
.
Tensor
],
def
set_value_
(
self
,
value
:
Optional
[
torch
.
Tensor
],
layout
:
Optional
[
str
]
=
None
):
layout
:
Optional
[
str
]
=
None
):
if
value
is
not
None
:
if
value
is
not
None
:
if
get_layout
(
layout
)
==
'csc'
:
if
get_layout
(
layout
)
==
'csc'
:
...
@@ -208,7 +205,8 @@ class SparseStorage(object):
...
@@ -208,7 +205,8 @@ class SparseStorage(object):
self
.
_value
=
value
self
.
_value
=
value
return
self
return
self
def
set_value
(
self
,
value
:
Optional
[
torch
.
Tensor
],
def
set_value
(
self
,
value
:
Optional
[
torch
.
Tensor
],
layout
:
Optional
[
str
]
=
None
):
layout
:
Optional
[
str
]
=
None
):
if
value
is
not
None
:
if
value
is
not
None
:
if
get_layout
(
layout
)
==
'csc'
:
if
get_layout
(
layout
)
==
'csc'
:
...
@@ -217,11 +215,18 @@ class SparseStorage(object):
...
@@ -217,11 +215,18 @@ class SparseStorage(object):
assert
value
.
device
==
self
.
_col
.
device
assert
value
.
device
==
self
.
_col
.
device
assert
value
.
size
(
0
)
==
self
.
_col
.
numel
()
assert
value
.
size
(
0
)
==
self
.
_col
.
numel
()
return
SparseStorage
(
row
=
self
.
_row
,
rowptr
=
self
.
_rowptr
,
col
=
self
.
_col
,
return
SparseStorage
(
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
row
=
self
.
_row
,
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
rowptr
=
self
.
_rowptr
,
colcount
=
self
.
_colcount
,
csr2csc
=
self
.
_csr2csc
,
col
=
self
.
_col
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
colcount
=
self
.
_colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
def
sparse_sizes
(
self
)
->
List
[
int
]:
def
sparse_sizes
(
self
)
->
List
[
int
]:
return
self
.
_sparse_sizes
return
self
.
_sparse_sizes
...
@@ -259,11 +264,18 @@ class SparseStorage(object):
...
@@ -259,11 +264,18 @@ class SparseStorage(object):
if
colcount
is
not
None
:
if
colcount
is
not
None
:
colcount
=
colcount
[:
-
diff_1
]
colcount
=
colcount
[:
-
diff_1
]
return
SparseStorage
(
row
=
self
.
_row
,
rowptr
=
rowptr
,
col
=
self
.
_col
,
return
SparseStorage
(
value
=
self
.
_value
,
sparse_sizes
=
sparse_sizes
,
row
=
self
.
_row
,
rowcount
=
rowcount
,
colptr
=
colptr
,
rowptr
=
rowptr
,
colcount
=
colcount
,
csr2csc
=
self
.
_csr2csc
,
col
=
self
.
_col
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
value
=
self
.
_value
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
def
has_rowcount
(
self
)
->
bool
:
def
has_rowcount
(
self
)
->
bool
:
return
self
.
_rowcount
is
not
None
return
self
.
_rowcount
is
not
None
...
@@ -308,8 +320,10 @@ class SparseStorage(object):
...
@@ -308,8 +320,10 @@ class SparseStorage(object):
if
colptr
is
not
None
:
if
colptr
is
not
None
:
colcount
=
colptr
[
1
:]
-
colptr
[:
-
1
]
colcount
=
colptr
[
1
:]
-
colptr
[:
-
1
]
else
:
else
:
colcount
=
scatter_add
(
torch
.
ones_like
(
self
.
_col
),
self
.
_col
,
colcount
=
scatter_add
(
dim_size
=
self
.
_sparse_sizes
[
1
])
torch
.
ones_like
(
self
.
_col
),
self
.
_col
,
dim_size
=
self
.
_sparse_sizes
[
1
])
self
.
_colcount
=
colcount
self
.
_colcount
=
colcount
return
colcount
return
colcount
...
@@ -361,10 +375,18 @@ class SparseStorage(object):
...
@@ -361,10 +375,18 @@ class SparseStorage(object):
value
=
segment_csr
(
value
,
ptr
,
reduce
=
reduce
)
value
=
segment_csr
(
value
,
ptr
,
reduce
=
reduce
)
value
=
value
[
0
]
if
isinstance
(
value
,
tuple
)
else
value
value
=
value
[
0
]
if
isinstance
(
value
,
tuple
)
else
value
return
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
return
SparseStorage
(
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
None
,
row
=
row
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
rowptr
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
None
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
def
fill_cache_
(
self
):
def
fill_cache_
(
self
):
self
.
row
()
self
.
row
()
...
@@ -399,12 +421,18 @@ class SparseStorage(object):
...
@@ -399,12 +421,18 @@ class SparseStorage(object):
return
count
return
count
def
copy
(
self
):
def
copy
(
self
):
return
SparseStorage
(
row
=
self
.
_row
,
rowptr
=
self
.
_rowptr
,
col
=
self
.
_col
,
return
SparseStorage
(
value
=
self
.
_value
,
row
=
self
.
_row
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowptr
=
self
.
_rowptr
,
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
col
=
self
.
_col
,
colcount
=
self
.
_colcount
,
csr2csc
=
self
.
_csr2csc
,
value
=
self
.
_value
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
colcount
=
self
.
_colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
def
clone
(
self
):
def
clone
(
self
):
row
=
self
.
_row
row
=
self
.
_row
...
@@ -432,11 +460,18 @@ class SparseStorage(object):
...
@@ -432,11 +460,18 @@ class SparseStorage(object):
csc2csr
=
self
.
_csc2csr
csc2csr
=
self
.
_csc2csr
if
csc2csr
is
not
None
:
if
csc2csr
is
not
None
:
csc2csr
=
csc2csr
.
clone
()
csc2csr
=
csc2csr
.
clone
()
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
return
SparseStorage
(
sparse_sizes
=
self
.
_sparse_sizes
,
row
=
row
,
rowcount
=
rowcount
,
colptr
=
colptr
,
rowptr
=
rowptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
col
=
col
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
def
type_as
(
self
,
tensor
=
torch
.
Tensor
):
def
type_as
(
self
,
tensor
=
torch
.
Tensor
):
value
=
self
.
_value
value
=
self
.
_value
...
@@ -477,11 +512,18 @@ class SparseStorage(object):
...
@@ -477,11 +512,18 @@ class SparseStorage(object):
csc2csr
=
self
.
_csc2csr
csc2csr
=
self
.
_csc2csr
if
csc2csr
is
not
None
:
if
csc2csr
is
not
None
:
csc2csr
=
csc2csr
.
to
(
tensor
.
device
,
non_blocking
=
non_blocking
)
csc2csr
=
csc2csr
.
to
(
tensor
.
device
,
non_blocking
=
non_blocking
)
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
return
SparseStorage
(
sparse_sizes
=
self
.
_sparse_sizes
,
row
=
row
,
rowcount
=
rowcount
,
colptr
=
colptr
,
rowptr
=
rowptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
col
=
col
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
def
pin_memory
(
self
):
def
pin_memory
(
self
):
row
=
self
.
_row
row
=
self
.
_row
...
@@ -509,11 +551,18 @@ class SparseStorage(object):
...
@@ -509,11 +551,18 @@ class SparseStorage(object):
csc2csr
=
self
.
_csc2csr
csc2csr
=
self
.
_csc2csr
if
csc2csr
is
not
None
:
if
csc2csr
is
not
None
:
csc2csr
=
csc2csr
.
pin_memory
()
csc2csr
=
csc2csr
.
pin_memory
()
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
return
SparseStorage
(
sparse_sizes
=
self
.
_sparse_sizes
,
row
=
row
,
rowcount
=
rowcount
,
colptr
=
colptr
,
rowptr
=
rowptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
col
=
col
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
def
is_pinned
(
self
)
->
bool
:
def
is_pinned
(
self
)
->
bool
:
is_pinned
=
True
is_pinned
=
True
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment