Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
f87afd09
Commit
f87afd09
authored
Jan 28, 2020
by
rusty1s
Browse files
tensor and storage mostly jittable
parent
631eee37
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
657 additions
and
413 deletions
+657
-413
test/test_jit.py
test/test_jit.py
+8
-4
torch_sparse/storage.py
torch_sparse/storage.py
+208
-70
torch_sparse/tensor.py
torch_sparse/tensor.py
+441
-339
No files found.
test/test_jit.py
View file @
f87afd09
...
...
@@ -30,9 +30,10 @@ class MyCell(torch.nn.Module):
self
.
linear
=
torch
.
nn
.
Linear
(
2
,
4
)
# def forward(self, x: torch.Tensor, ptr: torch.Tensor) -> torch.Tensor:
def
forward
(
self
,
x
:
torch
.
Tensor
,
adj
:
SparseStorage
)
->
torch
.
Tensor
:
out
,
_
=
torch
.
ops
.
torch_sparse_cpu
.
spmm
(
adj
.
rowptr
(),
adj
.
col
(),
None
,
x
,
'sum'
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
adj
:
SparseTensor
)
->
torch
.
Tensor
:
out
,
_
=
torch
.
ops
.
torch_sparse_cpu
.
spmm
(
adj
.
storage
.
rowptr
(),
adj
.
storage
.
col
(),
None
,
x
,
'sum'
)
return
out
...
...
@@ -67,7 +68,10 @@ def test_jit():
rowptr
=
torch
.
tensor
([
0
,
3
,
6
,
9
])
col
=
torch
.
tensor
([
0
,
1
,
2
,
0
,
1
,
2
,
0
,
1
,
2
])
adj
=
SparseStorage
(
rowptr
=
rowptr
,
col
=
col
)
adj
=
SparseTensor
(
rowptr
=
rowptr
,
col
=
col
)
scipy
=
adj
.
to_scipy
(
layout
=
'csr'
)
mat
=
SparseTensor
.
from_scipy
(
scipy
)
mat
.
fill_value_
(
2.3
)
# adj = {'rowptr': mat.storage.rowptr, 'col': mat.storage.col}
# foo = Foo(mat.storage.rowptr, mat.storage.col)
...
...
torch_sparse/storage.py
View file @
f87afd09
import
warnings
from
typing
import
Optional
,
List
,
Dict
,
Union
,
Any
from
typing
import
Optional
,
List
import
torch
from
torch_scatter
import
segment_csr
,
scatter_add
from
torch_sparse.utils
import
Final
,
is_scalar
# __cache__ = {'enabled': True}
# def is_cache_enabled():
# return __cache__['enabled']
# def set_cache_enabled(mode):
# __cache__['enabled'] = mode
# class no_cache(object):
# def __enter__(self):
# self.prev = is_cache_enabled()
# set_cache_enabled(False)
# def __exit__(self, *args):
# set_cache_enabled(self.prev)
# return False
# def __call__(self, func):
# def decorate_no_cache(*args, **kwargs):
# with self:
# return func(*args, **kwargs)
# return decorate_no_cache
def
optional
(
func
,
src
):
return
func
(
src
)
if
src
is
not
None
else
src
from
torch_sparse.utils
import
Final
layouts
:
Final
[
List
[
str
]]
=
[
'coo'
,
'csr'
,
'csc'
]
...
...
@@ -52,7 +23,7 @@ class SparseStorage(object):
_rowptr
:
Optional
[
torch
.
Tensor
]
_col
:
torch
.
Tensor
_value
:
Optional
[
torch
.
Tensor
]
_sparse_size
:
List
[
int
]
_sparse_size
s
:
List
[
int
]
_rowcount
:
Optional
[
torch
.
Tensor
]
_colptr
:
Optional
[
torch
.
Tensor
]
_colcount
:
Optional
[
torch
.
Tensor
]
...
...
@@ -63,7 +34,7 @@ class SparseStorage(object):
rowptr
:
Optional
[
torch
.
Tensor
]
=
None
,
col
:
Optional
[
torch
.
Tensor
]
=
None
,
value
:
Optional
[
torch
.
Tensor
]
=
None
,
sparse_size
:
Optional
[
List
[
int
]]
=
None
,
sparse_size
s
:
Optional
[
List
[
int
]]
=
None
,
rowcount
:
Optional
[
torch
.
Tensor
]
=
None
,
colptr
:
Optional
[
torch
.
Tensor
]
=
None
,
colcount
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -77,7 +48,7 @@ class SparseStorage(object):
assert
col
.
dim
()
==
1
col
=
col
.
contiguous
()
if
sparse_size
is
None
:
if
sparse_size
s
is
None
:
if
rowptr
is
not
None
:
M
=
rowptr
.
numel
()
-
1
elif
row
is
not
None
:
...
...
@@ -85,9 +56,9 @@ class SparseStorage(object):
else
:
raise
ValueError
N
=
col
.
max
().
item
()
+
1
sparse_size
=
torch
.
Size
([
int
(
M
),
int
(
N
)])
sparse_size
s
=
torch
.
Size
([
int
(
M
),
int
(
N
)])
else
:
assert
len
(
sparse_size
)
==
2
assert
len
(
sparse_size
s
)
==
2
if
row
is
not
None
:
assert
row
.
dtype
==
torch
.
long
...
...
@@ -100,7 +71,7 @@ class SparseStorage(object):
assert
rowptr
.
dtype
==
torch
.
long
assert
rowptr
.
device
==
col
.
device
assert
rowptr
.
dim
()
==
1
assert
rowptr
.
numel
()
-
1
==
sparse_size
[
0
]
assert
rowptr
.
numel
()
-
1
==
sparse_size
s
[
0
]
rowptr
=
rowptr
.
contiguous
()
if
value
is
not
None
:
...
...
@@ -112,21 +83,21 @@ class SparseStorage(object):
assert
rowcount
.
dtype
==
torch
.
long
assert
rowcount
.
device
==
col
.
device
assert
rowcount
.
dim
()
==
1
assert
rowcount
.
numel
()
==
sparse_size
[
0
]
assert
rowcount
.
numel
()
==
sparse_size
s
[
0
]
rowcount
=
rowcount
.
contiguous
()
if
colptr
is
not
None
:
assert
colptr
.
dtype
==
torch
.
long
assert
colptr
.
device
==
col
.
device
assert
colptr
.
dim
()
==
1
assert
colptr
.
numel
()
-
1
==
sparse_size
[
1
]
assert
colptr
.
numel
()
-
1
==
sparse_size
s
[
1
]
colptr
=
colptr
.
contiguous
()
if
colcount
is
not
None
:
assert
colcount
.
dtype
==
torch
.
long
assert
colcount
.
device
==
col
.
device
assert
colcount
.
dim
()
==
1
assert
colcount
.
numel
()
==
sparse_size
[
1
]
assert
colcount
.
numel
()
==
sparse_size
s
[
1
]
colcount
=
colcount
.
contiguous
()
if
csr2csc
is
not
None
:
...
...
@@ -147,7 +118,7 @@ class SparseStorage(object):
self
.
_rowptr
=
rowptr
self
.
_col
=
col
self
.
_value
=
value
self
.
_sparse_size
=
sparse_size
self
.
_sparse_size
s
=
sparse_size
s
self
.
_rowcount
=
rowcount
self
.
_colptr
=
colptr
self
.
_colcount
=
colcount
...
...
@@ -156,7 +127,7 @@ class SparseStorage(object):
if
not
is_sorted
:
idx
=
col
.
new_zeros
(
col
.
numel
()
+
1
)
idx
[
1
:]
=
sparse_size
[
1
]
*
self
.
row
()
+
col
idx
[
1
:]
=
sparse_size
s
[
1
]
*
self
.
row
()
+
col
if
(
idx
[
1
:]
<
idx
[:
-
1
]).
any
():
perm
=
idx
[
1
:].
argsort
()
self
.
_row
=
self
.
row
()[
perm
]
...
...
@@ -203,10 +174,10 @@ class SparseStorage(object):
if
row
is
not
None
:
if
row
.
is_cuda
:
rowptr
=
torch
.
ops
.
torch_sparse_cuda
.
ind2ptr
(
row
,
self
.
_sparse_size
[
0
])
row
,
self
.
_sparse_size
s
[
0
])
else
:
rowptr
=
torch
.
ops
.
torch_sparse_cpu
.
ind2ptr
(
row
,
self
.
_sparse_size
[
0
])
row
,
self
.
_sparse_size
s
[
0
])
self
.
_rowptr
=
rowptr
return
rowptr
...
...
@@ -243,27 +214,22 @@ class SparseStorage(object):
assert
value
.
size
(
0
)
==
self
.
_col
.
numel
()
return
SparseStorage
(
row
=
self
.
_row
,
rowptr
=
self
.
_rowptr
,
col
=
self
.
_col
,
value
=
value
,
sparse_size
=
self
.
_sparse_size
,
value
=
value
,
sparse_size
s
=
self
.
_sparse_size
s
,
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
colcount
=
self
.
_colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
def
fill_value_
(
self
,
fill_value
:
float
,
dtype
=
Optional
[
torch
.
dtype
]):
value
=
torch
.
empty
(
self
.
_col
.
numel
(),
dtype
,
device
=
self
.
_col
.
device
)
return
self
.
set_value_
(
value
.
fill_
(
fill_value
),
layout
=
'csr'
)
def
sparse_sizes
(
self
)
->
List
[
int
]:
return
self
.
_sparse_sizes
def
fill_value
(
self
,
fill_value
:
float
,
dtype
=
Optional
[
torch
.
dtype
]):
value
=
torch
.
empty
(
self
.
_col
.
numel
(),
dtype
,
device
=
self
.
_col
.
device
)
return
self
.
set_value
(
value
.
fill_
(
fill_value
),
layout
=
'csr'
)
def
sparse_size
(
self
,
dim
:
int
)
->
int
:
return
self
.
_sparse_sizes
[
dim
]
def
sparse_size
(
self
)
->
List
[
int
]:
return
self
.
_sparse_size
def
sparse_resize
(
self
,
sparse_sizes
:
List
[
int
]):
assert
len
(
sparse_sizes
)
==
2
old_sparse_sizes
,
nnz
=
self
.
_sparse_sizes
,
self
.
_col
.
numel
()
def
sparse_resize
(
self
,
sparse_size
:
List
[
int
]):
assert
len
(
sparse_size
)
==
2
old_sparse_size
,
nnz
=
self
.
_sparse_size
,
self
.
_col
.
numel
()
diff_0
=
sparse_size
[
0
]
-
old_sparse_size
[
0
]
diff_0
=
sparse_sizes
[
0
]
-
old_sparse_sizes
[
0
]
rowcount
,
rowptr
=
self
.
_rowcount
,
self
.
_rowptr
if
diff_0
>
0
:
if
rowptr
is
not
None
:
...
...
@@ -276,7 +242,7 @@ class SparseStorage(object):
if
rowcount
is
not
None
:
rowcount
=
rowcount
[:
-
diff_0
]
diff_1
=
sparse_size
[
1
]
-
old_sparse_size
[
1
]
diff_1
=
sparse_size
s
[
1
]
-
old_sparse_size
s
[
1
]
colcount
,
colptr
=
self
.
_colcount
,
self
.
_colptr
if
diff_1
>
0
:
if
colptr
is
not
None
:
...
...
@@ -290,7 +256,7 @@ class SparseStorage(object):
colcount
=
colcount
[:
-
diff_1
]
return
SparseStorage
(
row
=
self
.
_row
,
rowptr
=
rowptr
,
col
=
self
.
_col
,
value
=
self
.
_value
,
sparse_size
=
sparse_size
,
value
=
self
.
_value
,
sparse_size
s
=
sparse_size
s
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
...
...
@@ -319,9 +285,9 @@ class SparseStorage(object):
csr2csc
=
self
.
_csr2csc
if
csr2csc
is
not
None
:
colptr
=
torch
.
ops
.
torch_sparse_cpu
.
ind2ptr
(
self
.
_col
[
csr2csc
],
self
.
_sparse_size
[
1
])
self
.
_col
[
csr2csc
],
self
.
_sparse_size
s
[
1
])
else
:
colptr
=
self
.
_col
.
new_zeros
(
self
.
_sparse_size
[
1
]
+
1
)
colptr
=
self
.
_col
.
new_zeros
(
self
.
_sparse_size
s
[
1
]
+
1
)
torch
.
cumsum
(
self
.
colcount
(),
dim
=
0
,
out
=
colptr
[
1
:])
self
.
_colptr
=
colptr
return
colptr
...
...
@@ -340,7 +306,7 @@ class SparseStorage(object):
else
:
raise
NotImplementedError
# colcount = scatter_add(torch.ones_like(self._col), self._col,
# dim_size=self._sparse_size[1])
# dim_size=self._sparse_size
s
[1])
self
.
_colcount
=
colcount
return
colcount
...
...
@@ -352,7 +318,7 @@ class SparseStorage(object):
if
csr2csc
is
not
None
:
return
csr2csc
idx
=
self
.
_sparse_size
[
0
]
*
self
.
_col
+
self
.
row
()
idx
=
self
.
_sparse_size
s
[
0
]
*
self
.
_col
+
self
.
row
()
csr2csc
=
idx
.
argsort
()
self
.
_csr2csc
=
csr2csc
return
csr2csc
...
...
@@ -371,12 +337,12 @@ class SparseStorage(object):
def
is_coalesced
(
self
)
->
bool
:
idx
=
self
.
_col
.
new_full
((
self
.
_col
.
numel
()
+
1
,
),
-
1
)
idx
[
1
:]
=
self
.
_sparse_size
[
1
]
*
self
.
row
()
+
self
.
_col
idx
[
1
:]
=
self
.
_sparse_size
s
[
1
]
*
self
.
row
()
+
self
.
_col
return
bool
((
idx
[
1
:]
>
idx
[:
-
1
]).
all
())
def
coalesce
(
self
,
reduce
:
str
=
"add"
):
idx
=
self
.
_col
.
new_full
((
self
.
_col
.
numel
()
+
1
,
),
-
1
)
idx
[
1
:]
=
self
.
_sparse_size
[
1
]
*
self
.
row
()
+
self
.
_col
idx
[
1
:]
=
self
.
_sparse_size
s
[
1
]
*
self
.
row
()
+
self
.
_col
mask
=
idx
[
1
:]
>
idx
[:
-
1
]
if
mask
.
all
():
# Skip if indices are already coalesced.
...
...
@@ -394,7 +360,7 @@ class SparseStorage(object):
value
=
value
[
0
]
if
isinstance
(
value
,
tuple
)
else
value
return
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_size
=
self
.
_sparse_size
,
rowcount
=
None
,
sparse_size
s
=
self
.
_sparse_size
s
,
rowcount
=
None
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
...
...
@@ -418,7 +384,8 @@ class SparseStorage(object):
def
copy
(
self
):
return
SparseStorage
(
row
=
self
.
_row
,
rowptr
=
self
.
_rowptr
,
col
=
self
.
_col
,
value
=
self
.
_value
,
sparse_size
=
self
.
_sparse_size
,
value
=
self
.
_value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
colcount
=
self
.
_colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
...
...
@@ -430,6 +397,7 @@ class SparseStorage(object):
rowptr
=
self
.
_rowptr
if
rowptr
is
not
None
:
rowptr
=
rowptr
.
clone
()
col
=
self
.
_col
.
clone
()
value
=
self
.
_value
if
value
is
not
None
:
value
=
value
.
clone
()
...
...
@@ -448,8 +416,178 @@ class SparseStorage(object):
csc2csr
=
self
.
_csc2csr
if
csc2csr
is
not
None
:
csc2csr
=
csc2csr
.
clone
()
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
self
.
_col
.
clone
(),
value
=
value
,
sparse_size
=
self
.
_sparse_size
,
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
def
type_as
(
self
,
tensor
=
torch
.
Tensor
):
value
=
self
.
_value
if
value
is
not
None
:
if
tensor
.
dtype
==
value
.
dtype
:
return
self
else
:
return
self
.
set_value
(
value
.
type_as
(
tensor
),
layout
=
'coo'
)
else
:
return
self
def
device_as
(
self
,
tensor
:
torch
.
Tensor
,
non_blocking
:
bool
=
False
):
if
tensor
.
device
==
self
.
_col
.
device
:
return
self
row
=
self
.
_row
if
row
is
not
None
:
row
=
row
.
to
(
tensor
.
device
,
non_blocking
=
non_blocking
)
rowptr
=
self
.
_rowptr
if
rowptr
is
not
None
:
rowptr
=
rowptr
.
to
(
tensor
.
device
,
non_blocking
=
non_blocking
)
col
=
self
.
_col
.
to
(
tensor
.
device
,
non_blocking
=
non_blocking
)
value
=
self
.
_value
if
value
is
not
None
:
value
=
value
.
to
(
tensor
.
device
,
non_blocking
=
non_blocking
)
rowcount
=
self
.
_rowcount
if
rowcount
is
not
None
:
rowcount
=
rowcount
.
to
(
tensor
.
device
,
non_blocking
=
non_blocking
)
colptr
=
self
.
_colptr
if
colptr
is
not
None
:
colptr
=
colptr
.
to
(
tensor
.
device
,
non_blocking
=
non_blocking
)
colcount
=
self
.
_colcount
if
colcount
is
not
None
:
colcount
=
colcount
.
to
(
tensor
.
device
,
non_blocking
=
non_blocking
)
csr2csc
=
self
.
_csr2csc
if
csr2csc
is
not
None
:
csr2csc
=
csr2csc
.
to
(
tensor
.
device
,
non_blocking
=
non_blocking
)
csc2csr
=
self
.
_csc2csr
if
csc2csr
is
not
None
:
csc2csr
=
csc2csr
.
to
(
tensor
.
device
,
non_blocking
=
non_blocking
)
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
def
pin_memory
(
self
):
row
=
self
.
_row
if
row
is
not
None
:
row
=
row
.
pin_memory
()
rowptr
=
self
.
_rowptr
if
rowptr
is
not
None
:
rowptr
=
rowptr
.
pin_memory
()
col
=
self
.
_col
.
pin_memory
()
value
=
self
.
_value
if
value
is
not
None
:
value
=
value
.
pin_memory
()
rowcount
=
self
.
_rowcount
if
rowcount
is
not
None
:
rowcount
=
rowcount
.
pin_memory
()
colptr
=
self
.
_colptr
if
colptr
is
not
None
:
colptr
=
colptr
.
pin_memory
()
colcount
=
self
.
_colcount
if
colcount
is
not
None
:
colcount
=
colcount
.
pin_memory
()
csr2csc
=
self
.
_csr2csc
if
csr2csc
is
not
None
:
csr2csc
=
csr2csc
.
pin_memory
()
csc2csr
=
self
.
_csc2csr
if
csc2csr
is
not
None
:
csc2csr
=
csc2csr
.
pin_memory
()
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
def
is_pinned
(
self
)
->
bool
:
is_pinned
=
True
row
=
self
.
_row
if
row
is
not
None
:
is_pinned
=
is_pinned
and
row
.
is_pinned
()
rowptr
=
self
.
_rowptr
if
rowptr
is
not
None
:
is_pinned
=
is_pinned
and
rowptr
.
is_pinned
()
is_pinned
=
self
.
_col
.
is_pinned
()
value
=
self
.
_value
if
value
is
not
None
:
is_pinned
=
is_pinned
and
value
.
is_pinned
()
rowcount
=
self
.
_rowcount
if
rowcount
is
not
None
:
is_pinned
=
is_pinned
and
rowcount
.
is_pinned
()
colptr
=
self
.
_colptr
if
colptr
is
not
None
:
is_pinned
=
is_pinned
and
colptr
.
is_pinned
()
colcount
=
self
.
_colcount
if
colcount
is
not
None
:
is_pinned
=
is_pinned
and
colcount
.
is_pinned
()
csr2csc
=
self
.
_csr2csc
if
csr2csc
is
not
None
:
is_pinned
=
is_pinned
and
csr2csc
.
is_pinned
()
csc2csr
=
self
.
_csc2csr
if
csc2csr
is
not
None
:
is_pinned
=
is_pinned
and
csc2csr
.
is_pinned
()
return
is_pinned
@
torch
.
jit
.
ignore
def
share_memory_
(
self
)
->
SparseStorage
:
row
=
self
.
_row
if
row
is
not
None
:
row
.
share_memory_
()
rowptr
=
self
.
_rowptr
if
rowptr
is
not
None
:
rowptr
.
share_memory_
()
self
.
_col
.
share_memory_
()
value
=
self
.
_value
if
value
is
not
None
:
value
.
share_memory_
()
rowcount
=
self
.
_rowcount
if
rowcount
is
not
None
:
rowcount
.
share_memory_
()
colptr
=
self
.
_colptr
if
colptr
is
not
None
:
colptr
.
share_memory_
()
colcount
=
self
.
_colcount
if
colcount
is
not
None
:
colcount
.
share_memory_
()
csr2csc
=
self
.
_csr2csc
if
csr2csc
is
not
None
:
csr2csc
.
share_memory_
()
csc2csr
=
self
.
_csc2csr
if
csc2csr
is
not
None
:
csc2csr
.
share_memory_
()
@
torch
.
jit
.
ignore
def
is_shared
(
self
)
->
bool
:
is_shared
=
True
row
=
self
.
_row
if
row
is
not
None
:
is_shared
=
is_shared
and
row
.
is_shared
()
rowptr
=
self
.
_rowptr
if
rowptr
is
not
None
:
is_shared
=
is_shared
and
rowptr
.
is_shared
()
is_shared
=
is_shared
and
self
.
_col
.
is_shared
()
value
=
self
.
_value
if
value
is
not
None
:
is_shared
=
is_shared
and
value
.
is_shared
()
rowcount
=
self
.
_rowcount
if
rowcount
is
not
None
:
is_shared
=
is_shared
and
rowcount
.
is_shared
()
colptr
=
self
.
_colptr
if
colptr
is
not
None
:
is_shared
=
is_shared
and
colptr
.
is_shared
()
colcount
=
self
.
_colcount
if
colcount
is
not
None
:
is_shared
=
is_shared
and
colcount
.
is_shared
()
csr2csc
=
self
.
_csr2csc
if
csr2csc
is
not
None
:
is_shared
=
is_shared
and
csr2csc
.
is_shared
()
csc2csr
=
self
.
_csc2csr
if
csc2csr
is
not
None
:
is_shared
=
is_shared
and
csc2csr
.
is_shared
()
return
is_shared
SparseStorage
.
share_memory_
=
share_memory_
SparseStorage
.
is_shared
=
is_shared
torch_sparse/tensor.py
View file @
f87afd09
from
textwrap
import
indent
# from textwrap import indent
from
typing
import
Optional
,
List
,
Tuple
,
Union
import
torch
import
scipy.sparse
from
torch_sparse.storage
import
SparseStorage
,
get_layout
from
torch_sparse.transpose
import
t
from
torch_sparse.narrow
import
narrow
from
torch_sparse.select
import
select
from
torch_sparse.index_select
import
index_select
,
index_select_nnz
from
torch_sparse.masked_select
import
masked_select
,
masked_select_nnz
import
torch_sparse.reduce
from
torch_sparse.diag
import
remove_diag
,
set_diag
from
torch_sparse.matmul
import
matmul
from
torch_sparse.add
import
add
,
add_
,
add_nnz
,
add_nnz_
from
torch_sparse.mul
import
mul
,
mul_
,
mul_nnz
,
mul_nnz_
#
from torch_sparse.transpose import t
#
from torch_sparse.narrow import narrow
#
from torch_sparse.select import select
#
from torch_sparse.index_select import index_select, index_select_nnz
#
from torch_sparse.masked_select import masked_select, masked_select_nnz
#
import torch_sparse.reduce
#
from torch_sparse.diag import remove_diag, set_diag
#
from torch_sparse.matmul import matmul
#
from torch_sparse.add import add, add_, add_nnz, add_nnz_
#
from torch_sparse.mul import mul, mul_, mul_nnz, mul_nnz_
from
torch_sparse.utils
import
is_scalar
@
torch
.
jit
.
script
class
SparseTensor
(
object
):
def
__init__
(
self
,
row
=
None
,
rowptr
=
None
,
col
=
None
,
value
=
None
,
sparse_size
=
None
,
is_sorted
=
False
):
storage
:
SparseStorage
def
__init__
(
self
,
row
:
Optional
[
torch
.
Tensor
]
=
None
,
rowptr
:
Optional
[
torch
.
Tensor
]
=
None
,
col
:
Optional
[
torch
.
Tensor
]
=
None
,
value
:
Optional
[
torch
.
Tensor
]
=
None
,
sparse_sizes
:
List
[
int
]
=
None
,
is_sorted
:
bool
=
False
):
self
.
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_size
=
sparse_size
,
value
=
value
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
None
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
is_sorted
)
@
classmethod
def
from_storage
(
self
,
storage
):
def
from_storage
(
self
,
storage
:
SparseStorage
):
self
=
SparseTensor
.
__new__
(
SparseTensor
)
self
.
storage
=
storage
return
self
@
classmethod
def
from_dense
(
self
,
mat
):
def
from_dense
(
self
,
mat
:
torch
.
Tensor
):
if
mat
.
dim
()
>
2
:
index
=
mat
.
abs
().
sum
([
i
for
i
in
range
(
2
,
mat
.
dim
())]).
nonzero
()
else
:
index
=
mat
.
nonzero
()
index
=
index
.
t
()
row
,
col
=
index
.
t
().
contiguous
()
return
SparseTensor
(
row
=
row
,
col
=
col
,
value
=
mat
[
row
,
col
],
sparse_size
=
mat
.
size
()[:
2
],
is_sorted
=
True
)
row
,
col
=
index
[
0
],
index
[
1
]
return
SparseTensor
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
mat
[
row
,
col
],
sparse_size
s
=
mat
.
size
()[:
2
],
is_sorted
=
True
)
@
classmethod
def
from_torch_sparse_coo_tensor
(
self
,
mat
,
is_sorted
=
False
):
row
,
col
=
mat
.
_indices
()
return
SparseTensor
(
row
=
row
,
col
=
col
,
value
=
mat
.
_values
(),
sparse_size
=
mat
.
size
()[:
2
],
is_sorted
=
is_sorted
)
def
from_torch_sparse_coo_tensor
(
self
,
mat
:
torch
.
Tensor
):
mat
=
mat
.
coalesce
()
index
=
mat
.
_indices
()
row
,
col
=
index
[
0
],
index
[
1
]
return
SparseTensor
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
mat
.
_values
(),
sparse_sizes
=
mat
.
size
()[:
2
],
is_sorted
=
True
)
@
classmethod
def
from_scipy
(
self
,
mat
):
colptr
=
None
if
isinstance
(
mat
,
scipy
.
sparse
.
csc_matrix
):
colptr
=
torch
.
from_numpy
(
mat
.
indptr
).
to
(
torch
.
long
)
mat
=
mat
.
tocsr
()
# Pre-sort.
rowptr
=
torch
.
from_numpy
(
mat
.
indptr
).
to
(
torch
.
long
)
mat
=
mat
.
tocoo
()
row
=
torch
.
from_numpy
(
mat
.
row
).
to
(
torch
.
long
)
col
=
torch
.
from_numpy
(
mat
.
col
).
to
(
torch
.
long
)
value
=
torch
.
from_numpy
(
mat
.
data
)
sparse_size
=
mat
.
shape
[:
2
]
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_size
=
sparse_size
,
colptr
=
colptr
,
is_sorted
=
True
)
def
eye
(
self
,
M
:
int
,
N
:
Optional
[
int
]
=
None
,
options
:
Optional
[
torch
.
Tensor
]
=
None
,
has_value
:
bool
=
True
,
fill_cache
:
bool
=
False
):
return
SparseTensor
.
from_storage
(
storage
)
@
classmethod
def
eye
(
self
,
M
,
N
=
None
,
device
=
None
,
dtype
=
None
,
has_value
=
True
,
fill_cache
=
False
):
N
=
M
if
N
is
None
else
N
row
=
torch
.
arange
(
min
(
M
,
N
),
device
=
device
)
row
ptr
=
torch
.
arange
(
M
+
1
,
device
=
device
)
if
M
>
N
:
row
ptr
[
row
.
size
(
0
)
+
1
:]
=
row
.
size
(
0
)
if
options
is
not
None
:
row
=
torch
.
arange
(
min
(
M
,
N
)
,
device
=
options
.
device
)
else
:
row
=
torch
.
arange
(
min
(
M
,
N
)
)
col
=
row
value
=
None
rowptr
=
torch
.
arange
(
M
+
1
,
dtype
=
torch
.
long
,
device
=
row
.
device
)
if
M
>
N
:
rowptr
[
N
+
1
:]
=
M
value
:
Optional
[
torch
.
Tensor
]
=
None
if
has_value
:
value
=
torch
.
ones
(
row
.
size
(
0
),
dtype
=
dtype
,
device
=
device
)
if
options
is
not
None
:
value
=
torch
.
ones
(
row
.
numel
(),
dtype
=
options
.
dtype
,
device
=
row
.
device
)
else
:
value
=
torch
.
ones
(
row
.
numel
(),
device
=
row
.
device
)
rowcount
:
Optional
[
torch
.
Tensor
]
=
None
colptr
:
Optional
[
torch
.
Tensor
]
=
None
colcount
:
Optional
[
torch
.
Tensor
]
=
None
csr2csc
:
Optional
[
torch
.
Tensor
]
=
None
csc2csr
:
Optional
[
torch
.
Tensor
]
=
None
rowcount
=
colptr
=
colcount
=
csr2csc
=
csc2csr
=
None
if
fill_cache
:
rowcount
=
row
.
new_ones
(
M
)
rowcount
=
torch
.
ones
(
M
,
dtype
=
torch
.
long
,
device
=
row
.
device
)
if
M
>
N
:
rowcount
[
row
.
size
(
0
):]
=
0
colptr
=
torch
.
arange
(
N
+
1
,
device
=
device
)
colcount
=
col
.
new_ones
(
N
)
rowcount
[
N
:]
=
0
colptr
=
torch
.
arange
(
N
+
1
,
dtype
=
torch
.
long
,
device
=
row
.
device
)
colcount
=
torch
.
ones
(
N
,
dtype
=
torch
.
long
,
device
=
row
.
device
)
if
N
>
M
:
colptr
[
col
.
size
(
0
)
+
1
:]
=
col
.
size
(
0
)
colcount
[
col
.
size
(
0
)
:]
=
0
colptr
[
M
+
1
:]
=
M
colcount
[
M
:]
=
0
csr2csc
=
csc2csr
=
row
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_size
=
torch
.
Size
([
M
,
N
]),
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
return
SparseTensor
.
from_storage
(
storage
)
storage
:
SparseStorage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
torch
.
Size
([
M
,
N
]),
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
def
__copy__
(
self
):
self
=
SparseTensor
.
__new__
(
SparseTensor
)
self
.
storage
=
storage
return
self
def
copy
(
self
):
return
self
.
from_storage
(
self
.
storage
)
def
clone
(
self
):
return
self
.
from_storage
(
self
.
storage
.
clone
())
def
__deepcopy__
(
self
,
memo
):
new_sparse_tensor
=
self
.
clone
()
memo
[
id
(
self
)]
=
new_sparse_tensor
return
new_sparse_tensor
def
type_as
(
self
,
tensor
=
torch
.
Tensor
):
value
=
self
.
storage
.
_value
if
value
is
None
or
tensor
.
dtype
==
value
.
dtype
:
return
self
return
self
.
from_storage
(
self
.
storage
.
type_as
(
tensor
))
def
device_as
(
self
,
tensor
:
torch
.
Tensor
,
non_blocking
:
bool
=
False
):
if
tensor
.
device
==
self
.
device
():
return
self
return
self
.
from_storage
(
self
.
storage
.
device_as
(
tensor
,
non_blocking
))
# Formats #################################################################
def
coo
(
self
):
return
self
.
storage
.
row
,
self
.
storage
.
col
,
self
.
storage
.
value
def
coo
(
self
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]
:
return
self
.
storage
.
row
()
,
self
.
storage
.
col
()
,
self
.
storage
.
value
()
def
csr
(
self
):
return
self
.
storage
.
rowptr
,
self
.
storage
.
col
,
self
.
storage
.
value
def
csr
(
self
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]
:
return
self
.
storage
.
rowptr
()
,
self
.
storage
.
col
()
,
self
.
storage
.
value
()
def
csc
(
self
):
perm
=
self
.
storage
.
csr2csc
# Compute `csr2csc` first.
return
(
self
.
storage
.
colptr
,
self
.
storage
.
row
[
perm
],
self
.
storage
.
value
[
perm
]
if
self
.
has_value
()
else
None
)
def
csc
(
self
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
perm
=
self
.
storage
.
csr2csc
()
value
=
self
.
storage
.
value
()
if
value
is
not
None
:
value
=
value
[
perm
]
return
self
.
storage
.
colptr
(),
self
.
storage
.
row
()[
perm
],
value
# Storage inheritance #####################################################
def
has_value
(
self
):
def
has_value
(
self
)
->
bool
:
return
self
.
storage
.
has_value
()
def
set_value_
(
self
,
value
,
layout
=
None
,
dtype
=
None
):
self
.
storage
.
set_value_
(
value
,
layout
,
dtype
)
def
set_value_
(
self
,
value
:
Optional
[
torch
.
Tensor
],
layout
:
Optional
[
str
]
=
None
):
self
.
storage
.
set_value_
(
value
,
layout
)
return
self
def
set_value
(
self
,
value
,
layout
=
None
,
dtype
=
None
):
return
self
.
from_storage
(
self
.
storage
.
set_value
(
value
,
layout
,
dtype
))
def
set_value
(
self
,
value
:
Optional
[
torch
.
Tensor
],
layout
:
Optional
[
str
]
=
None
):
return
self
.
from_storage
(
self
.
storage
.
set_value
(
value
,
layout
))
def
sparse_sizes
(
self
)
->
List
[
int
]:
return
self
.
storage
.
sparse_sizes
()
def
sparse_size
(
self
,
dim
=
None
):
sparse_size
=
self
.
storage
.
sparse_size
return
sparse_size
if
dim
is
None
else
sparse_size
[
dim
]
def
sparse_size
(
self
,
dim
:
int
)
->
int
:
return
self
.
storage
.
sparse_sizes
()[
dim
]
def
sparse_resize
(
self
,
*
sizes
):
return
self
.
from_storage
(
self
.
storage
.
sparse_resize
(
*
sizes
))
def
sparse_resize
(
self
,
sparse_sizes
:
List
[
int
]
):
return
self
.
from_storage
(
self
.
storage
.
sparse_resize
(
sparse_
sizes
))
def
is_coalesced
(
self
):
def
is_coalesced
(
self
)
->
bool
:
return
self
.
storage
.
is_coalesced
()
def
coalesce
(
self
,
reduce
=
'
add
'
):
def
coalesce
(
self
,
reduce
:
str
=
"
add
"
):
return
self
.
from_storage
(
self
.
storage
.
coalesce
(
reduce
))
def
cached_keys
(
self
):
return
self
.
storage
.
cached_keys
()
def
fill_cache_
(
self
,
*
args
):
self
.
storage
.
fill_cache_
(
*
args
)
def
fill_cache_
(
self
):
self
.
storage
.
fill_cache_
()
return
self
def
clear_cache_
(
self
,
*
args
):
self
.
storage
.
clear_cache_
(
*
args
)
def
clear_cache_
(
self
):
self
.
storage
.
clear_cache_
()
return
self
# Utility functions #######################################################
def
dim
(
self
):
return
len
(
self
.
size
())
def
size
(
self
,
dim
=
None
):
size
=
self
.
sparse_size
()
size
+=
self
.
storage
.
value
.
size
()[
1
:]
if
self
.
has_value
()
else
()
return
size
if
dim
is
None
else
size
[
dim
]
@
property
def
shape
(
self
):
return
self
.
size
()
def
nnz
(
self
):
return
self
.
storage
.
col
.
numel
()
def
fill_value_
(
self
,
fill_value
:
float
,
options
:
Optional
[
torch
.
Tensor
]
=
None
):
if
options
is
not
None
:
value
=
torch
.
full
((
self
.
nnz
(),
),
fill_value
,
dtype
=
options
.
dtype
,
device
=
self
.
device
())
else
:
value
=
torch
.
full
((
self
.
nnz
(),
),
fill_value
,
device
=
self
.
device
())
return
self
.
set_value_
(
value
,
layout
=
'coo'
)
def
fill_value
(
self
,
fill_value
:
float
,
options
:
Optional
[
torch
.
Tensor
]
=
None
):
if
options
is
not
None
:
value
=
torch
.
full
((
self
.
nnz
(),
),
fill_value
,
dtype
=
options
.
dtype
,
device
=
self
.
device
())
else
:
value
=
torch
.
full
((
self
.
nnz
(),
),
fill_value
,
device
=
self
.
device
())
return
self
.
set_value
(
value
,
layout
=
'coo'
)
def
sizes
(
self
)
->
List
[
int
]:
sizes
=
self
.
sparse_sizes
()
value
=
self
.
storage
.
value
()
if
value
is
not
None
:
sizes
+=
value
.
size
()[
1
:]
return
sizes
def
size
(
self
,
dim
:
int
)
->
int
:
return
self
.
sizes
()[
dim
]
def
dim
(
self
)
->
int
:
return
len
(
self
.
sizes
())
def
nnz
(
self
)
->
int
:
return
self
.
storage
.
col
().
numel
()
def
numel
(
self
)
->
int
:
value
=
self
.
storage
.
value
()
if
value
is
not
None
:
return
value
.
numel
()
else
:
return
self
.
nnz
()
def
density
(
self
):
def
density
(
self
)
->
float
:
return
self
.
nnz
()
/
(
self
.
sparse_size
(
0
)
*
self
.
sparse_size
(
1
))
def
sparsity
(
self
):
def
sparsity
(
self
)
->
float
:
return
1
-
self
.
density
()
def
avg_row_length
(
self
):
def
avg_row_length
(
self
)
->
float
:
return
self
.
nnz
()
/
self
.
sparse_size
(
0
)
def
avg_col_length
(
self
):
def
avg_col_length
(
self
)
->
float
:
return
self
.
nnz
()
/
self
.
sparse_size
(
1
)
def
numel
(
self
):
return
self
.
value
.
numel
()
if
self
.
has_value
()
else
self
.
nnz
()
def
is_quadratic
(
self
):
def
is_quadratic
(
self
)
->
bool
:
return
self
.
sparse_size
(
0
)
==
self
.
sparse_size
(
1
)
def
is_symmetric
(
self
):
if
not
self
.
is_quadratic
:
def
is_symmetric
(
self
)
->
bool
:
if
not
self
.
is_quadratic
()
:
return
False
rowptr
,
col
,
value1
=
self
.
csr
()
...
...
@@ -207,296 +252,353 @@ class SparseTensor(object):
if
(
rowptr
!=
colptr
).
any
()
or
(
col
!=
row
).
any
():
return
False
if
not
self
.
has_value
()
:
if
value1
is
None
or
value2
is
None
:
return
True
return
(
value1
==
value2
).
all
()
.
item
(
)
else
:
return
bool
(
(
value1
==
value2
).
all
())
def
detach_
(
self
):
self
.
storage
.
apply_
(
lambda
x
:
x
.
detach_
())
value
=
self
.
storage
.
value
()
if
value
is
not
None
:
value
.
detach_
()
return
self
def
detach
(
self
):
return
self
.
from_storage
(
self
.
storage
.
apply
(
lambda
x
:
x
.
detach
()))
@
property
def
requires_grad
(
self
):
return
self
.
storage
.
value
.
requires_grad
if
self
.
has_value
()
else
False
value
=
self
.
storage
.
value
()
if
value
is
not
None
:
value
=
value
.
detach
()
return
self
.
set_value
(
value
,
layout
=
'coo'
)
def
requires_grad
(
self
)
->
bool
:
value
=
self
.
storage
.
value
()
if
value
is
not
None
:
return
value
.
requires_grad
else
:
return
False
def
requires_grad_
(
self
,
requires_grad
=
True
,
dtype
=
None
):
def
requires_grad_
(
self
,
requires_grad
:
bool
=
True
,
options
:
Optional
[
torch
.
Tensor
]
=
None
):
if
requires_grad
and
not
self
.
has_value
():
self
.
storage
.
set_value_
(
1
,
dtype
=
dtype
)
if
self
.
has_value
():
self
.
storage
.
value
.
requires_grad_
(
requires_grad
)
self
.
fill_value_
(
1.
,
options
=
options
)
value
=
self
.
storage
.
value
()
if
value
is
not
None
:
value
.
requires_grad_
(
requires_grad
)
return
self
def
pin_memory
(
self
):
return
self
.
from_storage
(
self
.
storage
.
apply
(
lambda
x
:
x
.
pin_memory
())
)
return
self
.
from_storage
(
self
.
storage
.
pin_memory
())
def
is_pinned
(
self
):
return
all
(
self
.
storage
.
map
(
lambda
x
:
x
.
is_pinned
()
))
def
is_pinned
(
self
)
->
bool
:
return
self
.
storage
.
is_pinned
()
def
share_memory_
(
self
)
:
self
.
storage
.
apply_
(
lambda
x
:
x
.
share_memory_
()
)
return
self
def
is_shared
(
self
)
:
return
all
(
self
.
storage
.
map
(
lambda
x
:
x
.
is_shared
())
)
def
options
(
self
)
->
torch
.
Tensor
:
value
=
self
.
storage
.
value
()
if
value
is
not
None
:
return
value
else
:
return
torch
.
tensor
(
0.
,
device
=
self
.
storage
.
col
().
device
)
@
property
def
device
(
self
):
return
self
.
storage
.
col
.
device
return
self
.
storage
.
col
()
.
device
def
cpu
(
self
):
return
self
.
from_storage
(
self
.
storage
.
apply
(
lambda
x
:
x
.
cpu
()))
def
cuda
(
self
,
device
=
None
,
non_blocking
=
False
,
**
kwargs
):
storage
=
self
.
storage
.
apply
(
lambda
x
:
x
.
cuda
(
device
,
non_blocking
,
**
kwargs
))
return
self
.
from_storage
(
storage
)
@
property
def
is_cuda
(
self
):
return
self
.
storage
.
col
.
is_cuda
@
property
def
dtype
(
self
):
return
self
.
storage
.
value
.
dtype
if
self
.
has_value
()
else
None
def
is_floating_point
(
self
):
value
=
self
.
storage
.
value
return
self
.
has_value
()
and
torch
.
is_floating_point
(
value
)
def
type
(
self
,
dtype
=
None
,
non_blocking
=
False
,
**
kwargs
):
if
dtype
is
None
:
return
self
.
dtype
if
dtype
==
self
.
dtype
:
return
self
return
self
.
device_as
(
torch
.
tensor
(
0.
),
non_blocking
=
False
)
storage
=
self
.
storage
.
apply_value
(
lambda
x
:
x
.
type
(
dtype
,
non_blocking
,
**
kwargs
))
return
self
.
from_storage
(
storage
)
def
to
(
self
,
*
args
,
**
kwargs
):
args
=
list
(
args
)
non_blocking
=
getattr
(
kwargs
,
'non_blocking'
,
False
)
storage
=
None
if
'device'
in
kwargs
:
device
=
kwargs
[
'device'
]
del
kwargs
[
'device'
]
storage
=
self
.
storage
.
apply
(
lambda
x
:
x
.
to
(
device
,
non_blocking
=
non_blocking
))
def
cuda
(
self
,
options
=
Optional
[
torch
.
Tensor
],
non_blocking
:
bool
=
False
):
if
options
is
not
None
:
return
self
.
device_as
(
options
,
non_blocking
)
else
:
for
arg
in
args
[:]:
if
isinstance
(
arg
,
str
)
or
isinstance
(
arg
,
torch
.
device
):
storage
=
self
.
storage
.
apply
(
lambda
x
:
x
.
to
(
arg
,
non_blocking
=
non_blocking
))
args
.
remove
(
arg
)
options
=
torch
.
tensor
(
0.
).
cuda
()
return
self
.
device_as
(
options
,
non_blocking
)
storage
=
self
.
storage
if
storage
is
None
else
storage
def
is_cuda
(
self
)
->
bool
:
return
self
.
storage
.
col
().
is_cuda
if
len
(
args
)
>
0
or
len
(
kwargs
)
>
0
:
storage
=
storage
.
apply_value
(
lambda
x
:
x
.
type
(
*
args
,
**
kwargs
))
def
dtype
(
self
)
:
return
self
.
options
().
dtype
if
storage
==
self
.
storage
:
# Nothing has been changed...
return
self
else
:
return
self
.
from_storage
(
storage
)
def
is_floating_point
(
self
)
->
bool
:
return
torch
.
is_floating_point
(
self
.
options
())
def
bfloat16
(
self
):
return
self
.
type
(
torch
.
bfloat16
)
return
self
.
type
_as
(
torch
.
tensor
(
0
,
dtype
=
torch
.
bfloat16
)
)
def
bool
(
self
):
return
self
.
type
(
torch
.
bool
)
return
self
.
type
_as
(
torch
.
tensor
(
0
,
dtype
=
torch
.
bool
)
)
def
byte
(
self
):
return
self
.
type
(
torch
.
by
te
)
return
self
.
type
_as
(
torch
.
te
nsor
(
0
,
dtype
=
torch
.
uint8
)
)
def
char
(
self
):
return
self
.
type
(
torch
.
char
)
return
self
.
type
_as
(
torch
.
tensor
(
0
,
dtype
=
torch
.
int8
)
)
def
half
(
self
):
return
self
.
type
(
torch
.
half
)
return
self
.
type
_as
(
torch
.
tensor
(
0
,
dtype
=
torch
.
half
)
)
def
float
(
self
):
return
self
.
type
(
torch
.
float
)
return
self
.
type
_as
(
torch
.
tensor
(
0
,
dtype
=
torch
.
float
)
)
def
double
(
self
):
return
self
.
type
(
torch
.
double
)
return
self
.
type
_as
(
torch
.
tensor
(
0
,
dtype
=
torch
.
double
)
)
def
short
(
self
):
return
self
.
type
(
torch
.
short
)
return
self
.
type
_as
(
torch
.
tensor
(
0
,
dtype
=
torch
.
short
)
)
def
int
(
self
):
return
self
.
type
(
torch
.
int
)
return
self
.
type
_as
(
torch
.
tensor
(
0
,
dtype
=
torch
.
int
)
)
def
long
(
self
):
return
self
.
type
(
torch
.
long
)
return
self
.
type
_as
(
torch
.
tensor
(
0
,
dtype
=
torch
.
long
)
)
# Conversions #############################################################
def
to_dense
(
self
,
dtype
=
None
):
dtype
=
dtype
or
self
.
dtype
def
to_dense
(
self
,
options
:
Optional
[
torch
.
Tensor
]
=
None
):
row
,
col
,
value
=
self
.
coo
()
mat
=
torch
.
zeros
(
self
.
size
(),
dtype
=
dtype
,
device
=
self
.
device
)
mat
[
row
,
col
]
=
value
if
self
.
has_value
()
else
1
if
options
is
not
None
:
mat
=
torch
.
zeros
(
self
.
sizes
(),
dtype
=
options
.
dtype
,
device
=
self
.
device
())
else
:
mat
=
torch
.
zeros
(
self
.
sizes
(),
device
=
self
.
device
())
if
value
is
not
None
:
mat
[
row
,
col
]
=
value
else
:
mat
[
row
,
col
]
=
torch
.
ones
(
self
.
nnz
(),
dtype
=
mat
.
dtype
,
device
=
mat
.
device
)
return
mat
def
to_torch_sparse_coo_tensor
(
self
,
dtype
=
None
,
requires_grad
=
False
):
def
to_torch_sparse_coo_tensor
(
self
,
options
:
Optional
[
torch
.
Tensor
]
):
row
,
col
,
value
=
self
.
coo
()
index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
if
value
is
None
:
value
=
torch
.
ones
(
self
.
nnz
(),
dtype
=
dtype
,
device
=
self
.
device
)
return
torch
.
sparse_coo_tensor
(
index
,
value
,
self
.
size
(),
device
=
self
.
device
,
requires_grad
=
requires_grad
)
def
to_scipy
(
self
,
layout
=
None
,
dtype
=
None
):
assert
self
.
dim
()
==
2
layout
=
get_layout
(
layout
)
if
not
self
.
has_value
():
ones
=
torch
.
ones
(
self
.
nnz
(),
dtype
=
dtype
).
numpy
()
if
layout
==
'coo'
:
row
,
col
,
value
=
self
.
coo
()
row
=
row
.
detach
().
cpu
().
numpy
()
col
=
col
.
detach
().
cpu
().
numpy
()
value
=
value
.
detach
().
cpu
().
numpy
()
if
self
.
has_value
()
else
ones
return
scipy
.
sparse
.
coo_matrix
((
value
,
(
row
,
col
)),
self
.
size
())
elif
layout
==
'csr'
:
rowptr
,
col
,
value
=
self
.
csr
()
rowptr
=
rowptr
.
detach
().
cpu
().
numpy
()
col
=
col
.
detach
().
cpu
().
numpy
()
value
=
value
.
detach
().
cpu
().
numpy
()
if
self
.
has_value
()
else
ones
return
scipy
.
sparse
.
csr_matrix
((
value
,
col
,
rowptr
),
self
.
size
())
elif
layout
==
'csc'
:
colptr
,
row
,
value
=
self
.
csc
()
colptr
=
colptr
.
detach
().
cpu
().
numpy
()
row
=
row
.
detach
().
cpu
().
numpy
()
value
=
value
.
detach
().
cpu
().
numpy
()
if
self
.
has_value
()
else
ones
return
scipy
.
sparse
.
csc_matrix
((
value
,
row
,
colptr
),
self
.
size
())
# Standard Operators ######################################################
def
__getitem__
(
self
,
index
):
index
=
list
(
index
)
if
isinstance
(
index
,
tuple
)
else
[
index
]
# More than one `Ellipsis` is not allowed...
if
len
([
i
for
i
in
index
if
not
torch
.
is_tensor
(
i
)
and
i
==
...])
>
1
:
raise
SyntaxError
dim
=
0
out
=
self
while
len
(
index
)
>
0
:
item
=
index
.
pop
(
0
)
if
isinstance
(
item
,
int
):
out
=
out
.
select
(
dim
,
item
)
dim
+=
1
elif
isinstance
(
item
,
slice
):
if
item
.
step
is
not
None
:
raise
ValueError
(
'Step parameter not yet supported.'
)
start
=
0
if
item
.
start
is
None
else
item
.
start
start
=
self
.
size
(
dim
)
+
start
if
start
<
0
else
start
stop
=
self
.
size
(
dim
)
if
item
.
stop
is
None
else
item
.
stop
stop
=
self
.
size
(
dim
)
+
stop
if
stop
<
0
else
stop
out
=
out
.
narrow
(
dim
,
start
,
max
(
stop
-
start
,
0
))
dim
+=
1
elif
torch
.
is_tensor
(
item
):
if
item
.
dtype
==
torch
.
bool
:
out
=
out
.
masked_select
(
dim
,
item
)
dim
+=
1
elif
item
.
dtype
==
torch
.
long
:
out
=
out
.
index_select
(
dim
,
item
)
dim
+=
1
elif
item
==
Ellipsis
:
if
self
.
dim
()
-
len
(
index
)
<
dim
:
raise
SyntaxError
dim
=
self
.
dim
()
-
len
(
index
)
if
options
is
not
None
:
value
=
torch
.
ones
(
self
.
nnz
(),
dtype
=
options
.
dtype
,
device
=
self
.
device
())
else
:
raise
SyntaxError
value
=
torch
.
ones
(
self
.
nnz
(),
device
=
self
.
device
())
return
out
return
torch
.
sparse_coo_tensor
(
index
,
value
,
self
.
sizes
())
def
__add__
(
self
,
other
):
return
self
.
add
(
other
)
def
__radd__
(
self
,
other
):
return
self
.
add
(
other
)
# # Standard Operators ######################################################
def
__iadd__
(
self
,
other
):
return
self
.
add_
(
other
)
# def __getitem__(self, index):
# index = list(index) if isinstance(index, tuple) else [index]
# # More than one `Ellipsis` is not allowed...
# if len([i for i in index if not torch.is_tensor(i) and i == ...]) > 1:
# raise SyntaxError
def
__mul__
(
self
,
other
):
return
self
.
mul
(
other
)
# dim = 0
# out = self
# while len(index) > 0:
# item = index.pop(0)
# if isinstance(item, int):
# out = out.select(dim, item)
# dim += 1
# elif isinstance(item, slice):
# if item.step is not None:
# raise ValueError('Step parameter not yet supported.')
def
__rmul__
(
self
,
other
):
return
self
.
mul
(
other
)
#
start = 0 if item.start is None else item.start
#
start = self.size(dim) + start if start < 0 else start
def
__imul__
(
self
,
other
):
return
self
.
mul_
(
other
)
# stop = self.size(dim) if item.stop is None else item.stop
#
stop = self.size(dim) + stop if stop < 0 else stop
def
__matmul__
(
self
,
other
):
return
matmul
(
self
,
other
,
reduce
=
'sum'
)
# out = out.narrow(dim, start, max(stop - start, 0))
# dim += 1
# elif torch.is_tensor(item):
# if item.dtype == torch.bool:
# out = out.masked_select(dim, item)
# dim += 1
# elif item.dtype == torch.long:
# out = out.index_select(dim, item)
# dim += 1
# elif item == Ellipsis:
# if self.dim() - len(index) < dim:
# raise SyntaxError
# dim = self.dim() - len(index)
# else:
# raise SyntaxError
# String Reputation #######################################################
#
return out
def
__repr__
(
self
):
i
=
' '
*
6
row
,
col
,
value
=
self
.
coo
()
infos
=
[]
infos
+=
[
f
'row=
{
indent
(
row
.
__repr__
(),
i
)[
len
(
i
):]
}
'
]
infos
+=
[
f
'col=
{
indent
(
col
.
__repr__
(),
i
)[
len
(
i
):]
}
'
]
# def __add__(self, other):
# return self.add(other)
# def __radd__(self, other):
# return self.add(other)
# def __iadd__(self, other):
# return self.add_(other)
# def __mul__(self, other):
# return self.mul(other)
# def __rmul__(self, other):
# return self.mul(other)
if
self
.
has_value
(
):
infos
+=
[
f
'val=
{
indent
(
value
.
__repr__
(),
i
)[
len
(
i
):]
}
'
]
# def __imul__(self, other
):
#
return self.mul_(other)
infos
+=
[
f
'size=
{
tuple
(
self
.
size
())
}
, '
f
'nnz=
{
self
.
nnz
()
}
, '
f
'density=
{
100
*
self
.
density
():.
02
f
}
%'
]
infos
=
',
\n
'
.
join
(
infos
)
# def __matmul__(self, other):
# return matmul(self, other, reduce='sum')
i
=
' '
*
(
len
(
self
.
__class__
.
__name__
)
+
1
)
return
f
'
{
self
.
__class__
.
__name__
}
(
{
indent
(
infos
,
i
)[
len
(
i
):]
}
)'
# # String Reputation #######################################################
# def __repr__(self):
# i = ' ' * 6
# row, col, value = self.coo()
# infos = []
# infos += [f'row={indent(row.__repr__(), i)[len(i):]}']
# infos += [f'col={indent(col.__repr__(), i)[len(i):]}']
# if self.has_value():
# infos += [f'val={indent(value.__repr__(), i)[len(i):]}']
# infos += [
# f'size={tuple(self.size())}, '
# f'nnz={self.nnz()}, '
# f'density={100 * self.density():.02f}%'
# ]
# infos = ',\n'.join(infos)
# i = ' ' * (len(self.__class__.__name__) + 1)
# return f'{self.__class__.__name__}({indent(infos, i)[len(i):]})'
# Bindings ####################################################################
SparseTensor
.
t
=
t
SparseTensor
.
narrow
=
narrow
SparseTensor
.
select
=
select
SparseTensor
.
index_select
=
index_select
SparseTensor
.
index_select_nnz
=
index_select_nnz
SparseTensor
.
masked_select
=
masked_select
SparseTensor
.
masked_select_nnz
=
masked_select_nnz
SparseTensor
.
reduction
=
torch_sparse
.
reduce
.
reduction
SparseTensor
.
sum
=
torch_sparse
.
reduce
.
sum
SparseTensor
.
mean
=
torch_sparse
.
reduce
.
mean
SparseTensor
.
min
=
torch_sparse
.
reduce
.
min
SparseTensor
.
max
=
torch_sparse
.
reduce
.
max
SparseTensor
.
remove_diag
=
remove_diag
SparseTensor
.
set_diag
=
set_diag
SparseTensor
.
matmul
=
matmul
SparseTensor
.
add
=
add
SparseTensor
.
add_
=
add_
SparseTensor
.
add_nnz
=
add_nnz
SparseTensor
.
add_nnz_
=
add_nnz_
SparseTensor
.
mul
=
mul
SparseTensor
.
mul_
=
mul_
SparseTensor
.
mul_nnz
=
mul_nnz
SparseTensor
.
mul_nnz_
=
mul_nnz_
# Fix for PyTorch<=1.3 (https://github.com/pytorch/pytorch/pull/31769):
# SparseTensor.t = t
# SparseTensor.narrow = narrow
# SparseTensor.select = select
# SparseTensor.index_select = index_select
# SparseTensor.index_select_nnz = index_select_nnz
# SparseTensor.masked_select = masked_select
# SparseTensor.masked_select_nnz = masked_select_nnz
# SparseTensor.reduction = torch_sparse.reduce.reduction
# SparseTensor.sum = torch_sparse.reduce.sum
# SparseTensor.mean = torch_sparse.reduce.mean
# SparseTensor.min = torch_sparse.reduce.min
# SparseTensor.max = torch_sparse.reduce.max
# SparseTensor.remove_diag = remove_diag
# SparseTensor.set_diag = set_diag
# SparseTensor.matmul = matmul
# SparseTensor.add = add
# SparseTensor.add_ = add_
# SparseTensor.add_nnz = add_nnz
# SparseTensor.add_nnz_ = add_nnz_
# SparseTensor.mul = mul
# SparseTensor.mul_ = mul_
# SparseTensor.mul_nnz = mul_nnz
# SparseTensor.mul_nnz_ = mul_nnz_
# Python Bindings #############################################################
Dtype
=
Optional
[
torch
.
dtype
]
Device
=
Optional
[
Union
[
torch
.
device
,
str
]]
@
torch
.
jit
.
ignore
def
share_memory_
(
self
:
SparseTensor
)
->
SparseTensor
:
self
.
storage
.
share_memory_
()
@
torch
.
jit
.
ignore
def
is_shared
(
self
:
SparseTensor
)
->
bool
:
return
self
.
storage
.
is_shared
()
@
torch
.
jit
.
ignore
def
to
(
self
,
*
args
,
**
kwargs
):
dtype
:
Dtype
=
getattr
(
kwargs
,
'dtype'
,
None
)
device
:
Device
=
getattr
(
kwargs
,
'device'
,
None
)
non_blocking
:
bool
=
getattr
(
kwargs
,
'non_blocking'
,
False
)
for
arg
in
args
:
if
isinstance
(
arg
,
str
)
or
isinstance
(
arg
,
torch
.
device
):
device
=
arg
if
isinstance
(
arg
,
torch
.
dtype
):
dtype
=
arg
if
dtype
is
not
None
:
self
=
self
.
type_as
(
torch
.
tensor
(
0.
,
dtype
=
dtype
))
if
device
is
not
None
:
self
=
self
.
device_as
(
torch
.
tensor
(
0.
,
device
=
device
),
non_blocking
)
return
self
SparseTensor
.
share_memory_
=
share_memory_
SparseTensor
.
is_shared
=
is_shared
SparseTensor
.
to
=
to
# Scipy Conversions ###########################################################
ScipySparseMatrix
=
Union
[
scipy
.
sparse
.
coo_matrix
,
scipy
.
sparse
.
csr_matrix
,
scipy
.
sparse
.
csc_matrix
]
@
torch
.
jit
.
ignore
def
from_scipy
(
mat
:
ScipySparseMatrix
)
->
SparseTensor
:
colptr
=
None
if
isinstance
(
mat
,
scipy
.
sparse
.
csc_matrix
):
colptr
=
torch
.
from_numpy
(
mat
.
indptr
).
to
(
torch
.
long
)
mat
=
mat
.
tocsr
()
rowptr
=
torch
.
from_numpy
(
mat
.
indptr
).
to
(
torch
.
long
)
mat
=
mat
.
tocoo
()
row
=
torch
.
from_numpy
(
mat
.
row
).
to
(
torch
.
long
)
col
=
torch
.
from_numpy
(
mat
.
col
).
to
(
torch
.
long
)
value
=
torch
.
from_numpy
(
mat
.
data
)
sparse_sizes
=
mat
.
shape
[:
2
]
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
None
,
colptr
=
colptr
,
colcount
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
return
SparseTensor
.
from_storage
(
storage
)
@
torch
.
jit
.
ignore
def
to_scipy
(
self
:
SparseTensor
,
layout
:
Optional
[
str
]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
)
->
ScipySparseMatrix
:
assert
self
.
dim
()
==
2
layout
=
get_layout
(
layout
)
if
not
self
.
has_value
():
ones
=
torch
.
ones
(
self
.
nnz
(),
dtype
=
dtype
).
numpy
()
if
layout
==
'coo'
:
row
,
col
,
value
=
self
.
coo
()
row
=
row
.
detach
().
cpu
().
numpy
()
col
=
col
.
detach
().
cpu
().
numpy
()
value
=
value
.
detach
().
cpu
().
numpy
()
if
self
.
has_value
()
else
ones
return
scipy
.
sparse
.
coo_matrix
((
value
,
(
row
,
col
)),
self
.
sizes
())
elif
layout
==
'csr'
:
rowptr
,
col
,
value
=
self
.
csr
()
rowptr
=
rowptr
.
detach
().
cpu
().
numpy
()
col
=
col
.
detach
().
cpu
().
numpy
()
value
=
value
.
detach
().
cpu
().
numpy
()
if
self
.
has_value
()
else
ones
return
scipy
.
sparse
.
csr_matrix
((
value
,
col
,
rowptr
),
self
.
sizes
())
elif
layout
==
'csc'
:
colptr
,
row
,
value
=
self
.
csc
()
colptr
=
colptr
.
detach
().
cpu
().
numpy
()
row
=
row
.
detach
().
cpu
().
numpy
()
value
=
value
.
detach
().
cpu
().
numpy
()
if
self
.
has_value
()
else
ones
return
scipy
.
sparse
.
csc_matrix
((
value
,
row
,
colptr
),
self
.
sizes
())
SparseTensor
.
from_scipy
=
from_scipy
SparseTensor
.
to_scipy
=
to_scipy
# Hacky fixes #################################################################
# Fix standard operators of `torch.Tensor` for PyTorch<=1.3.
# https://github.com/pytorch/pytorch/pull/31769
TORCH_MAJOR
=
int
(
torch
.
__version__
.
split
(
'.'
)[
0
])
TORCH_MINOR
=
int
(
torch
.
__version__
.
split
(
'.'
)[
1
])
if
(
TORCH_MAJOR
<
1
)
or
(
TORCH_MAJOR
==
1
and
TORCH_MINOR
<
4
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment