Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
f87afd09
Commit
f87afd09
authored
Jan 28, 2020
by
rusty1s
Browse files
tensor and storage mostly jittable
parent
631eee37
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
657 additions
and
413 deletions
+657
-413
test/test_jit.py
test/test_jit.py
+8
-4
torch_sparse/storage.py
torch_sparse/storage.py
+208
-70
torch_sparse/tensor.py
torch_sparse/tensor.py
+441
-339
No files found.
test/test_jit.py
View file @
f87afd09
...
@@ -30,9 +30,10 @@ class MyCell(torch.nn.Module):
...
@@ -30,9 +30,10 @@ class MyCell(torch.nn.Module):
self
.
linear
=
torch
.
nn
.
Linear
(
2
,
4
)
self
.
linear
=
torch
.
nn
.
Linear
(
2
,
4
)
# def forward(self, x: torch.Tensor, ptr: torch.Tensor) -> torch.Tensor:
# def forward(self, x: torch.Tensor, ptr: torch.Tensor) -> torch.Tensor:
def
forward
(
self
,
x
:
torch
.
Tensor
,
adj
:
SparseStorage
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
,
adj
:
SparseTensor
)
->
torch
.
Tensor
:
out
,
_
=
torch
.
ops
.
torch_sparse_cpu
.
spmm
(
adj
.
rowptr
(),
adj
.
col
(),
None
,
out
,
_
=
torch
.
ops
.
torch_sparse_cpu
.
spmm
(
adj
.
storage
.
rowptr
(),
x
,
'sum'
)
adj
.
storage
.
col
(),
None
,
x
,
'sum'
)
return
out
return
out
...
@@ -67,7 +68,10 @@ def test_jit():
...
@@ -67,7 +68,10 @@ def test_jit():
rowptr
=
torch
.
tensor
([
0
,
3
,
6
,
9
])
rowptr
=
torch
.
tensor
([
0
,
3
,
6
,
9
])
col
=
torch
.
tensor
([
0
,
1
,
2
,
0
,
1
,
2
,
0
,
1
,
2
])
col
=
torch
.
tensor
([
0
,
1
,
2
,
0
,
1
,
2
,
0
,
1
,
2
])
adj
=
SparseStorage
(
rowptr
=
rowptr
,
col
=
col
)
adj
=
SparseTensor
(
rowptr
=
rowptr
,
col
=
col
)
scipy
=
adj
.
to_scipy
(
layout
=
'csr'
)
mat
=
SparseTensor
.
from_scipy
(
scipy
)
mat
.
fill_value_
(
2.3
)
# adj = {'rowptr': mat.storage.rowptr, 'col': mat.storage.col}
# adj = {'rowptr': mat.storage.rowptr, 'col': mat.storage.col}
# foo = Foo(mat.storage.rowptr, mat.storage.col)
# foo = Foo(mat.storage.rowptr, mat.storage.col)
...
...
torch_sparse/storage.py
View file @
f87afd09
import
warnings
import
warnings
from
typing
import
Optional
,
List
,
Dict
,
Union
,
Any
from
typing
import
Optional
,
List
import
torch
import
torch
from
torch_scatter
import
segment_csr
,
scatter_add
from
torch_scatter
import
segment_csr
,
scatter_add
from
torch_sparse.utils
import
Final
,
is_scalar
from
torch_sparse.utils
import
Final
# __cache__ = {'enabled': True}
# def is_cache_enabled():
# return __cache__['enabled']
# def set_cache_enabled(mode):
# __cache__['enabled'] = mode
# class no_cache(object):
# def __enter__(self):
# self.prev = is_cache_enabled()
# set_cache_enabled(False)
# def __exit__(self, *args):
# set_cache_enabled(self.prev)
# return False
# def __call__(self, func):
# def decorate_no_cache(*args, **kwargs):
# with self:
# return func(*args, **kwargs)
# return decorate_no_cache
def
optional
(
func
,
src
):
return
func
(
src
)
if
src
is
not
None
else
src
layouts
:
Final
[
List
[
str
]]
=
[
'coo'
,
'csr'
,
'csc'
]
layouts
:
Final
[
List
[
str
]]
=
[
'coo'
,
'csr'
,
'csc'
]
...
@@ -52,7 +23,7 @@ class SparseStorage(object):
...
@@ -52,7 +23,7 @@ class SparseStorage(object):
_rowptr
:
Optional
[
torch
.
Tensor
]
_rowptr
:
Optional
[
torch
.
Tensor
]
_col
:
torch
.
Tensor
_col
:
torch
.
Tensor
_value
:
Optional
[
torch
.
Tensor
]
_value
:
Optional
[
torch
.
Tensor
]
_sparse_size
:
List
[
int
]
_sparse_size
s
:
List
[
int
]
_rowcount
:
Optional
[
torch
.
Tensor
]
_rowcount
:
Optional
[
torch
.
Tensor
]
_colptr
:
Optional
[
torch
.
Tensor
]
_colptr
:
Optional
[
torch
.
Tensor
]
_colcount
:
Optional
[
torch
.
Tensor
]
_colcount
:
Optional
[
torch
.
Tensor
]
...
@@ -63,7 +34,7 @@ class SparseStorage(object):
...
@@ -63,7 +34,7 @@ class SparseStorage(object):
rowptr
:
Optional
[
torch
.
Tensor
]
=
None
,
rowptr
:
Optional
[
torch
.
Tensor
]
=
None
,
col
:
Optional
[
torch
.
Tensor
]
=
None
,
col
:
Optional
[
torch
.
Tensor
]
=
None
,
value
:
Optional
[
torch
.
Tensor
]
=
None
,
value
:
Optional
[
torch
.
Tensor
]
=
None
,
sparse_size
:
Optional
[
List
[
int
]]
=
None
,
sparse_size
s
:
Optional
[
List
[
int
]]
=
None
,
rowcount
:
Optional
[
torch
.
Tensor
]
=
None
,
rowcount
:
Optional
[
torch
.
Tensor
]
=
None
,
colptr
:
Optional
[
torch
.
Tensor
]
=
None
,
colptr
:
Optional
[
torch
.
Tensor
]
=
None
,
colcount
:
Optional
[
torch
.
Tensor
]
=
None
,
colcount
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -77,7 +48,7 @@ class SparseStorage(object):
...
@@ -77,7 +48,7 @@ class SparseStorage(object):
assert
col
.
dim
()
==
1
assert
col
.
dim
()
==
1
col
=
col
.
contiguous
()
col
=
col
.
contiguous
()
if
sparse_size
is
None
:
if
sparse_size
s
is
None
:
if
rowptr
is
not
None
:
if
rowptr
is
not
None
:
M
=
rowptr
.
numel
()
-
1
M
=
rowptr
.
numel
()
-
1
elif
row
is
not
None
:
elif
row
is
not
None
:
...
@@ -85,9 +56,9 @@ class SparseStorage(object):
...
@@ -85,9 +56,9 @@ class SparseStorage(object):
else
:
else
:
raise
ValueError
raise
ValueError
N
=
col
.
max
().
item
()
+
1
N
=
col
.
max
().
item
()
+
1
sparse_size
=
torch
.
Size
([
int
(
M
),
int
(
N
)])
sparse_size
s
=
torch
.
Size
([
int
(
M
),
int
(
N
)])
else
:
else
:
assert
len
(
sparse_size
)
==
2
assert
len
(
sparse_size
s
)
==
2
if
row
is
not
None
:
if
row
is
not
None
:
assert
row
.
dtype
==
torch
.
long
assert
row
.
dtype
==
torch
.
long
...
@@ -100,7 +71,7 @@ class SparseStorage(object):
...
@@ -100,7 +71,7 @@ class SparseStorage(object):
assert
rowptr
.
dtype
==
torch
.
long
assert
rowptr
.
dtype
==
torch
.
long
assert
rowptr
.
device
==
col
.
device
assert
rowptr
.
device
==
col
.
device
assert
rowptr
.
dim
()
==
1
assert
rowptr
.
dim
()
==
1
assert
rowptr
.
numel
()
-
1
==
sparse_size
[
0
]
assert
rowptr
.
numel
()
-
1
==
sparse_size
s
[
0
]
rowptr
=
rowptr
.
contiguous
()
rowptr
=
rowptr
.
contiguous
()
if
value
is
not
None
:
if
value
is
not
None
:
...
@@ -112,21 +83,21 @@ class SparseStorage(object):
...
@@ -112,21 +83,21 @@ class SparseStorage(object):
assert
rowcount
.
dtype
==
torch
.
long
assert
rowcount
.
dtype
==
torch
.
long
assert
rowcount
.
device
==
col
.
device
assert
rowcount
.
device
==
col
.
device
assert
rowcount
.
dim
()
==
1
assert
rowcount
.
dim
()
==
1
assert
rowcount
.
numel
()
==
sparse_size
[
0
]
assert
rowcount
.
numel
()
==
sparse_size
s
[
0
]
rowcount
=
rowcount
.
contiguous
()
rowcount
=
rowcount
.
contiguous
()
if
colptr
is
not
None
:
if
colptr
is
not
None
:
assert
colptr
.
dtype
==
torch
.
long
assert
colptr
.
dtype
==
torch
.
long
assert
colptr
.
device
==
col
.
device
assert
colptr
.
device
==
col
.
device
assert
colptr
.
dim
()
==
1
assert
colptr
.
dim
()
==
1
assert
colptr
.
numel
()
-
1
==
sparse_size
[
1
]
assert
colptr
.
numel
()
-
1
==
sparse_size
s
[
1
]
colptr
=
colptr
.
contiguous
()
colptr
=
colptr
.
contiguous
()
if
colcount
is
not
None
:
if
colcount
is
not
None
:
assert
colcount
.
dtype
==
torch
.
long
assert
colcount
.
dtype
==
torch
.
long
assert
colcount
.
device
==
col
.
device
assert
colcount
.
device
==
col
.
device
assert
colcount
.
dim
()
==
1
assert
colcount
.
dim
()
==
1
assert
colcount
.
numel
()
==
sparse_size
[
1
]
assert
colcount
.
numel
()
==
sparse_size
s
[
1
]
colcount
=
colcount
.
contiguous
()
colcount
=
colcount
.
contiguous
()
if
csr2csc
is
not
None
:
if
csr2csc
is
not
None
:
...
@@ -147,7 +118,7 @@ class SparseStorage(object):
...
@@ -147,7 +118,7 @@ class SparseStorage(object):
self
.
_rowptr
=
rowptr
self
.
_rowptr
=
rowptr
self
.
_col
=
col
self
.
_col
=
col
self
.
_value
=
value
self
.
_value
=
value
self
.
_sparse_size
=
sparse_size
self
.
_sparse_size
s
=
sparse_size
s
self
.
_rowcount
=
rowcount
self
.
_rowcount
=
rowcount
self
.
_colptr
=
colptr
self
.
_colptr
=
colptr
self
.
_colcount
=
colcount
self
.
_colcount
=
colcount
...
@@ -156,7 +127,7 @@ class SparseStorage(object):
...
@@ -156,7 +127,7 @@ class SparseStorage(object):
if
not
is_sorted
:
if
not
is_sorted
:
idx
=
col
.
new_zeros
(
col
.
numel
()
+
1
)
idx
=
col
.
new_zeros
(
col
.
numel
()
+
1
)
idx
[
1
:]
=
sparse_size
[
1
]
*
self
.
row
()
+
col
idx
[
1
:]
=
sparse_size
s
[
1
]
*
self
.
row
()
+
col
if
(
idx
[
1
:]
<
idx
[:
-
1
]).
any
():
if
(
idx
[
1
:]
<
idx
[:
-
1
]).
any
():
perm
=
idx
[
1
:].
argsort
()
perm
=
idx
[
1
:].
argsort
()
self
.
_row
=
self
.
row
()[
perm
]
self
.
_row
=
self
.
row
()[
perm
]
...
@@ -203,10 +174,10 @@ class SparseStorage(object):
...
@@ -203,10 +174,10 @@ class SparseStorage(object):
if
row
is
not
None
:
if
row
is
not
None
:
if
row
.
is_cuda
:
if
row
.
is_cuda
:
rowptr
=
torch
.
ops
.
torch_sparse_cuda
.
ind2ptr
(
rowptr
=
torch
.
ops
.
torch_sparse_cuda
.
ind2ptr
(
row
,
self
.
_sparse_size
[
0
])
row
,
self
.
_sparse_size
s
[
0
])
else
:
else
:
rowptr
=
torch
.
ops
.
torch_sparse_cpu
.
ind2ptr
(
rowptr
=
torch
.
ops
.
torch_sparse_cpu
.
ind2ptr
(
row
,
self
.
_sparse_size
[
0
])
row
,
self
.
_sparse_size
s
[
0
])
self
.
_rowptr
=
rowptr
self
.
_rowptr
=
rowptr
return
rowptr
return
rowptr
...
@@ -243,27 +214,22 @@ class SparseStorage(object):
...
@@ -243,27 +214,22 @@ class SparseStorage(object):
assert
value
.
size
(
0
)
==
self
.
_col
.
numel
()
assert
value
.
size
(
0
)
==
self
.
_col
.
numel
()
return
SparseStorage
(
row
=
self
.
_row
,
rowptr
=
self
.
_rowptr
,
col
=
self
.
_col
,
return
SparseStorage
(
row
=
self
.
_row
,
rowptr
=
self
.
_rowptr
,
col
=
self
.
_col
,
value
=
value
,
sparse_size
=
self
.
_sparse_size
,
value
=
value
,
sparse_size
s
=
self
.
_sparse_size
s
,
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
colcount
=
self
.
_colcount
,
csr2csc
=
self
.
_csr2csc
,
colcount
=
self
.
_colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
def
fill_value_
(
self
,
fill_value
:
float
,
dtype
=
Optional
[
torch
.
dtype
]):
def
sparse_sizes
(
self
)
->
List
[
int
]:
value
=
torch
.
empty
(
self
.
_col
.
numel
(),
dtype
,
device
=
self
.
_col
.
device
)
return
self
.
_sparse_sizes
return
self
.
set_value_
(
value
.
fill_
(
fill_value
),
layout
=
'csr'
)
def
fill_value
(
self
,
fill_value
:
float
,
dtype
=
Optional
[
torch
.
dtype
]):
def
sparse_size
(
self
,
dim
:
int
)
->
int
:
value
=
torch
.
empty
(
self
.
_col
.
numel
(),
dtype
,
device
=
self
.
_col
.
device
)
return
self
.
_sparse_sizes
[
dim
]
return
self
.
set_value
(
value
.
fill_
(
fill_value
),
layout
=
'csr'
)
def
sparse_size
(
self
)
->
List
[
int
]:
def
sparse_resize
(
self
,
sparse_sizes
:
List
[
int
]):
return
self
.
_sparse_size
assert
len
(
sparse_sizes
)
==
2
old_sparse_sizes
,
nnz
=
self
.
_sparse_sizes
,
self
.
_col
.
numel
()
def
sparse_resize
(
self
,
sparse_size
:
List
[
int
]):
diff_0
=
sparse_sizes
[
0
]
-
old_sparse_sizes
[
0
]
assert
len
(
sparse_size
)
==
2
old_sparse_size
,
nnz
=
self
.
_sparse_size
,
self
.
_col
.
numel
()
diff_0
=
sparse_size
[
0
]
-
old_sparse_size
[
0
]
rowcount
,
rowptr
=
self
.
_rowcount
,
self
.
_rowptr
rowcount
,
rowptr
=
self
.
_rowcount
,
self
.
_rowptr
if
diff_0
>
0
:
if
diff_0
>
0
:
if
rowptr
is
not
None
:
if
rowptr
is
not
None
:
...
@@ -276,7 +242,7 @@ class SparseStorage(object):
...
@@ -276,7 +242,7 @@ class SparseStorage(object):
if
rowcount
is
not
None
:
if
rowcount
is
not
None
:
rowcount
=
rowcount
[:
-
diff_0
]
rowcount
=
rowcount
[:
-
diff_0
]
diff_1
=
sparse_size
[
1
]
-
old_sparse_size
[
1
]
diff_1
=
sparse_size
s
[
1
]
-
old_sparse_size
s
[
1
]
colcount
,
colptr
=
self
.
_colcount
,
self
.
_colptr
colcount
,
colptr
=
self
.
_colcount
,
self
.
_colptr
if
diff_1
>
0
:
if
diff_1
>
0
:
if
colptr
is
not
None
:
if
colptr
is
not
None
:
...
@@ -290,7 +256,7 @@ class SparseStorage(object):
...
@@ -290,7 +256,7 @@ class SparseStorage(object):
colcount
=
colcount
[:
-
diff_1
]
colcount
=
colcount
[:
-
diff_1
]
return
SparseStorage
(
row
=
self
.
_row
,
rowptr
=
rowptr
,
col
=
self
.
_col
,
return
SparseStorage
(
row
=
self
.
_row
,
rowptr
=
rowptr
,
col
=
self
.
_col
,
value
=
self
.
_value
,
sparse_size
=
sparse_size
,
value
=
self
.
_value
,
sparse_size
s
=
sparse_size
s
,
rowcount
=
rowcount
,
colptr
=
colptr
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
self
.
_csr2csc
,
colcount
=
colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
...
@@ -319,9 +285,9 @@ class SparseStorage(object):
...
@@ -319,9 +285,9 @@ class SparseStorage(object):
csr2csc
=
self
.
_csr2csc
csr2csc
=
self
.
_csr2csc
if
csr2csc
is
not
None
:
if
csr2csc
is
not
None
:
colptr
=
torch
.
ops
.
torch_sparse_cpu
.
ind2ptr
(
colptr
=
torch
.
ops
.
torch_sparse_cpu
.
ind2ptr
(
self
.
_col
[
csr2csc
],
self
.
_sparse_size
[
1
])
self
.
_col
[
csr2csc
],
self
.
_sparse_size
s
[
1
])
else
:
else
:
colptr
=
self
.
_col
.
new_zeros
(
self
.
_sparse_size
[
1
]
+
1
)
colptr
=
self
.
_col
.
new_zeros
(
self
.
_sparse_size
s
[
1
]
+
1
)
torch
.
cumsum
(
self
.
colcount
(),
dim
=
0
,
out
=
colptr
[
1
:])
torch
.
cumsum
(
self
.
colcount
(),
dim
=
0
,
out
=
colptr
[
1
:])
self
.
_colptr
=
colptr
self
.
_colptr
=
colptr
return
colptr
return
colptr
...
@@ -340,7 +306,7 @@ class SparseStorage(object):
...
@@ -340,7 +306,7 @@ class SparseStorage(object):
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
# colcount = scatter_add(torch.ones_like(self._col), self._col,
# colcount = scatter_add(torch.ones_like(self._col), self._col,
# dim_size=self._sparse_size[1])
# dim_size=self._sparse_size
s
[1])
self
.
_colcount
=
colcount
self
.
_colcount
=
colcount
return
colcount
return
colcount
...
@@ -352,7 +318,7 @@ class SparseStorage(object):
...
@@ -352,7 +318,7 @@ class SparseStorage(object):
if
csr2csc
is
not
None
:
if
csr2csc
is
not
None
:
return
csr2csc
return
csr2csc
idx
=
self
.
_sparse_size
[
0
]
*
self
.
_col
+
self
.
row
()
idx
=
self
.
_sparse_size
s
[
0
]
*
self
.
_col
+
self
.
row
()
csr2csc
=
idx
.
argsort
()
csr2csc
=
idx
.
argsort
()
self
.
_csr2csc
=
csr2csc
self
.
_csr2csc
=
csr2csc
return
csr2csc
return
csr2csc
...
@@ -371,12 +337,12 @@ class SparseStorage(object):
...
@@ -371,12 +337,12 @@ class SparseStorage(object):
def
is_coalesced
(
self
)
->
bool
:
def
is_coalesced
(
self
)
->
bool
:
idx
=
self
.
_col
.
new_full
((
self
.
_col
.
numel
()
+
1
,
),
-
1
)
idx
=
self
.
_col
.
new_full
((
self
.
_col
.
numel
()
+
1
,
),
-
1
)
idx
[
1
:]
=
self
.
_sparse_size
[
1
]
*
self
.
row
()
+
self
.
_col
idx
[
1
:]
=
self
.
_sparse_size
s
[
1
]
*
self
.
row
()
+
self
.
_col
return
bool
((
idx
[
1
:]
>
idx
[:
-
1
]).
all
())
return
bool
((
idx
[
1
:]
>
idx
[:
-
1
]).
all
())
def
coalesce
(
self
,
reduce
:
str
=
"add"
):
def
coalesce
(
self
,
reduce
:
str
=
"add"
):
idx
=
self
.
_col
.
new_full
((
self
.
_col
.
numel
()
+
1
,
),
-
1
)
idx
=
self
.
_col
.
new_full
((
self
.
_col
.
numel
()
+
1
,
),
-
1
)
idx
[
1
:]
=
self
.
_sparse_size
[
1
]
*
self
.
row
()
+
self
.
_col
idx
[
1
:]
=
self
.
_sparse_size
s
[
1
]
*
self
.
row
()
+
self
.
_col
mask
=
idx
[
1
:]
>
idx
[:
-
1
]
mask
=
idx
[
1
:]
>
idx
[:
-
1
]
if
mask
.
all
():
# Skip if indices are already coalesced.
if
mask
.
all
():
# Skip if indices are already coalesced.
...
@@ -394,7 +360,7 @@ class SparseStorage(object):
...
@@ -394,7 +360,7 @@ class SparseStorage(object):
value
=
value
[
0
]
if
isinstance
(
value
,
tuple
)
else
value
value
=
value
[
0
]
if
isinstance
(
value
,
tuple
)
else
value
return
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
return
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_size
=
self
.
_sparse_size
,
rowcount
=
None
,
sparse_size
s
=
self
.
_sparse_size
s
,
rowcount
=
None
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
csc2csr
=
None
,
is_sorted
=
True
)
...
@@ -418,7 +384,8 @@ class SparseStorage(object):
...
@@ -418,7 +384,8 @@ class SparseStorage(object):
def
copy
(
self
):
def
copy
(
self
):
return
SparseStorage
(
row
=
self
.
_row
,
rowptr
=
self
.
_rowptr
,
col
=
self
.
_col
,
return
SparseStorage
(
row
=
self
.
_row
,
rowptr
=
self
.
_rowptr
,
col
=
self
.
_col
,
value
=
self
.
_value
,
sparse_size
=
self
.
_sparse_size
,
value
=
self
.
_value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
colcount
=
self
.
_colcount
,
csr2csc
=
self
.
_csr2csc
,
colcount
=
self
.
_colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
...
@@ -430,6 +397,7 @@ class SparseStorage(object):
...
@@ -430,6 +397,7 @@ class SparseStorage(object):
rowptr
=
self
.
_rowptr
rowptr
=
self
.
_rowptr
if
rowptr
is
not
None
:
if
rowptr
is
not
None
:
rowptr
=
rowptr
.
clone
()
rowptr
=
rowptr
.
clone
()
col
=
self
.
_col
.
clone
()
value
=
self
.
_value
value
=
self
.
_value
if
value
is
not
None
:
if
value
is
not
None
:
value
=
value
.
clone
()
value
=
value
.
clone
()
...
@@ -448,8 +416,178 @@ class SparseStorage(object):
...
@@ -448,8 +416,178 @@ class SparseStorage(object):
csc2csr
=
self
.
_csc2csr
csc2csr
=
self
.
_csc2csr
if
csc2csr
is
not
None
:
if
csc2csr
is
not
None
:
csc2csr
=
csc2csr
.
clone
()
csc2csr
=
csc2csr
.
clone
()
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
self
.
_col
.
clone
(),
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
value
=
value
,
sparse_size
=
self
.
_sparse_size
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
def
type_as
(
self
,
tensor
=
torch
.
Tensor
):
value
=
self
.
_value
if
value
is
not
None
:
if
tensor
.
dtype
==
value
.
dtype
:
return
self
else
:
return
self
.
set_value
(
value
.
type_as
(
tensor
),
layout
=
'coo'
)
else
:
return
self
def
device_as
(
self
,
tensor
:
torch
.
Tensor
,
non_blocking
:
bool
=
False
):
if
tensor
.
device
==
self
.
_col
.
device
:
return
self
row
=
self
.
_row
if
row
is
not
None
:
row
=
row
.
to
(
tensor
.
device
,
non_blocking
=
non_blocking
)
rowptr
=
self
.
_rowptr
if
rowptr
is
not
None
:
rowptr
=
rowptr
.
to
(
tensor
.
device
,
non_blocking
=
non_blocking
)
col
=
self
.
_col
.
to
(
tensor
.
device
,
non_blocking
=
non_blocking
)
value
=
self
.
_value
if
value
is
not
None
:
value
=
value
.
to
(
tensor
.
device
,
non_blocking
=
non_blocking
)
rowcount
=
self
.
_rowcount
if
rowcount
is
not
None
:
rowcount
=
rowcount
.
to
(
tensor
.
device
,
non_blocking
=
non_blocking
)
colptr
=
self
.
_colptr
if
colptr
is
not
None
:
colptr
=
colptr
.
to
(
tensor
.
device
,
non_blocking
=
non_blocking
)
colcount
=
self
.
_colcount
if
colcount
is
not
None
:
colcount
=
colcount
.
to
(
tensor
.
device
,
non_blocking
=
non_blocking
)
csr2csc
=
self
.
_csr2csc
if
csr2csc
is
not
None
:
csr2csc
=
csr2csc
.
to
(
tensor
.
device
,
non_blocking
=
non_blocking
)
csc2csr
=
self
.
_csc2csr
if
csc2csr
is
not
None
:
csc2csr
=
csc2csr
.
to
(
tensor
.
device
,
non_blocking
=
non_blocking
)
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
def
pin_memory
(
self
):
row
=
self
.
_row
if
row
is
not
None
:
row
=
row
.
pin_memory
()
rowptr
=
self
.
_rowptr
if
rowptr
is
not
None
:
rowptr
=
rowptr
.
pin_memory
()
col
=
self
.
_col
.
pin_memory
()
value
=
self
.
_value
if
value
is
not
None
:
value
=
value
.
pin_memory
()
rowcount
=
self
.
_rowcount
if
rowcount
is
not
None
:
rowcount
=
rowcount
.
pin_memory
()
colptr
=
self
.
_colptr
if
colptr
is
not
None
:
colptr
=
colptr
.
pin_memory
()
colcount
=
self
.
_colcount
if
colcount
is
not
None
:
colcount
=
colcount
.
pin_memory
()
csr2csc
=
self
.
_csr2csc
if
csr2csc
is
not
None
:
csr2csc
=
csr2csc
.
pin_memory
()
csc2csr
=
self
.
_csc2csr
if
csc2csr
is
not
None
:
csc2csr
=
csc2csr
.
pin_memory
()
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
csc2csr
=
csc2csr
,
is_sorted
=
True
)
def
is_pinned
(
self
)
->
bool
:
is_pinned
=
True
row
=
self
.
_row
if
row
is
not
None
:
is_pinned
=
is_pinned
and
row
.
is_pinned
()
rowptr
=
self
.
_rowptr
if
rowptr
is
not
None
:
is_pinned
=
is_pinned
and
rowptr
.
is_pinned
()
is_pinned
=
self
.
_col
.
is_pinned
()
value
=
self
.
_value
if
value
is
not
None
:
is_pinned
=
is_pinned
and
value
.
is_pinned
()
rowcount
=
self
.
_rowcount
if
rowcount
is
not
None
:
is_pinned
=
is_pinned
and
rowcount
.
is_pinned
()
colptr
=
self
.
_colptr
if
colptr
is
not
None
:
is_pinned
=
is_pinned
and
colptr
.
is_pinned
()
colcount
=
self
.
_colcount
if
colcount
is
not
None
:
is_pinned
=
is_pinned
and
colcount
.
is_pinned
()
csr2csc
=
self
.
_csr2csc
if
csr2csc
is
not
None
:
is_pinned
=
is_pinned
and
csr2csc
.
is_pinned
()
csc2csr
=
self
.
_csc2csr
if
csc2csr
is
not
None
:
is_pinned
=
is_pinned
and
csc2csr
.
is_pinned
()
return
is_pinned
@
torch
.
jit
.
ignore
def
share_memory_
(
self
)
->
SparseStorage
:
row
=
self
.
_row
if
row
is
not
None
:
row
.
share_memory_
()
rowptr
=
self
.
_rowptr
if
rowptr
is
not
None
:
rowptr
.
share_memory_
()
self
.
_col
.
share_memory_
()
value
=
self
.
_value
if
value
is
not
None
:
value
.
share_memory_
()
rowcount
=
self
.
_rowcount
if
rowcount
is
not
None
:
rowcount
.
share_memory_
()
colptr
=
self
.
_colptr
if
colptr
is
not
None
:
colptr
.
share_memory_
()
colcount
=
self
.
_colcount
if
colcount
is
not
None
:
colcount
.
share_memory_
()
csr2csc
=
self
.
_csr2csc
if
csr2csc
is
not
None
:
csr2csc
.
share_memory_
()
csc2csr
=
self
.
_csc2csr
if
csc2csr
is
not
None
:
csc2csr
.
share_memory_
()
@
torch
.
jit
.
ignore
def
is_shared
(
self
)
->
bool
:
is_shared
=
True
row
=
self
.
_row
if
row
is
not
None
:
is_shared
=
is_shared
and
row
.
is_shared
()
rowptr
=
self
.
_rowptr
if
rowptr
is
not
None
:
is_shared
=
is_shared
and
rowptr
.
is_shared
()
is_shared
=
is_shared
and
self
.
_col
.
is_shared
()
value
=
self
.
_value
if
value
is
not
None
:
is_shared
=
is_shared
and
value
.
is_shared
()
rowcount
=
self
.
_rowcount
if
rowcount
is
not
None
:
is_shared
=
is_shared
and
rowcount
.
is_shared
()
colptr
=
self
.
_colptr
if
colptr
is
not
None
:
is_shared
=
is_shared
and
colptr
.
is_shared
()
colcount
=
self
.
_colcount
if
colcount
is
not
None
:
is_shared
=
is_shared
and
colcount
.
is_shared
()
csr2csc
=
self
.
_csr2csc
if
csr2csc
is
not
None
:
is_shared
=
is_shared
and
csr2csc
.
is_shared
()
csc2csr
=
self
.
_csc2csr
if
csc2csr
is
not
None
:
is_shared
=
is_shared
and
csc2csr
.
is_shared
()
return
is_shared
SparseStorage
.
share_memory_
=
share_memory_
SparseStorage
.
is_shared
=
is_shared
torch_sparse/tensor.py
View file @
f87afd09
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment