Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
f59fe649
Commit
f59fe649
authored
Jan 27, 2020
by
rusty1s
Browse files
beginning of torch script support
parent
c4484dbb
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
391 additions
and
233 deletions
+391
-233
test/test_jit.py
test/test_jit.py
+80
-0
torch_sparse/coalesce.py
torch_sparse/coalesce.py
+2
-1
torch_sparse/diag.py
torch_sparse/diag.py
+7
-4
torch_sparse/matmul.py
torch_sparse/matmul.py
+8
-2
torch_sparse/spspmm.py
torch_sparse/spspmm.py
+4
-3
torch_sparse/storage.py
torch_sparse/storage.py
+282
-217
torch_sparse/utils.py
torch_sparse/utils.py
+8
-6
No files found.
test/test_jit.py
0 → 100644
View file @
f59fe649
import
torch
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.storage
import
SparseStorage
from
typing
import
Dict
,
Any
# class MyTensor(dict):
# def __init__(self, rowptr, col):
# self['rowptr'] = rowptr
# self['col'] = col
# def rowptr(self: Dict[str, torch.Tensor]):
# return self['rowptr']
@
torch
.
jit
.
script
class
Foo
:
rowptr
:
torch
.
Tensor
col
:
torch
.
Tensor
def
__init__
(
self
,
rowptr
:
torch
.
Tensor
,
col
:
torch
.
Tensor
):
self
.
rowptr
=
rowptr
self
.
col
=
col
class
MyCell
(
torch
.
nn
.
Module
):
def
__init__
(
self
):
super
(
MyCell
,
self
).
__init__
()
self
.
linear
=
torch
.
nn
.
Linear
(
2
,
4
)
# def forward(self, x: torch.Tensor, ptr: torch.Tensor) -> torch.Tensor:
def
forward
(
self
,
x
:
torch
.
Tensor
,
adj
:
SparseStorage
)
->
torch
.
Tensor
:
out
,
_
=
torch
.
ops
.
torch_sparse_cpu
.
spmm
(
adj
.
rowptr
(),
adj
.
col
(),
None
,
x
,
'sum'
)
return
out
# ind = torch.ops.torch_sparse_cpu.ptr2ind(ptr, ptr[-1].item())
# # ind = ptr2ind(ptr, E)
# x_j = x[ind]
# out = self.linear(x_j)
# return out
def
test_jit
():
my_cell
=
MyCell
()
# x = torch.rand(3, 2)
# ptr = torch.tensor([0, 2, 4, 6])
# out = my_cell(x, ptr)
# print()
# print(out)
# traced_cell = torch.jit.trace(my_cell, (x, ptr))
# print(traced_cell)
# out = traced_cell(x, ptr)
# print(out)
x
=
torch
.
randn
(
3
,
2
)
# adj = torch.randn(3, 3)
# adj = SparseTensor.from_dense(adj)
# adj = Foo(adj.storage.rowptr, adj.storage.col)
# adj = adj.storage
rowptr
=
torch
.
tensor
([
0
,
3
,
6
,
9
])
col
=
torch
.
tensor
([
0
,
1
,
2
,
0
,
1
,
2
,
0
,
1
,
2
])
adj
=
SparseStorage
(
rowptr
=
rowptr
,
col
=
col
)
# adj = {'rowptr': mat.storage.rowptr, 'col': mat.storage.col}
# foo = Foo(mat.storage.rowptr, mat.storage.col)
# adj = MyTensor(mat.storage.rowptr, mat.storage.col)
traced_cell
=
torch
.
jit
.
script
(
my_cell
)
print
(
traced_cell
)
out
=
traced_cell
(
x
,
adj
)
print
(
out
)
# # print(traced_cell.code)
torch_sparse/coalesce.py
View file @
f59fe649
import
torch
import
torch_scatter
from
.unique
import
unique
#
from .unique import unique
def
coalesce
(
index
,
value
,
m
,
n
,
op
=
'add'
,
fill_value
=
0
):
...
...
@@ -22,6 +22,7 @@ def coalesce(index, value, m, n, op='add', fill_value=0):
:rtype: (:class:`LongTensor`, :class:`Tensor`)
"""
raise
NotImplementedError
row
,
col
=
index
...
...
torch_sparse/diag.py
View file @
f59fe649
import
torch
from
torch_sparse.utils
import
ext
def
remove_diag
(
src
,
k
=
0
):
row
,
col
,
value
=
src
.
coo
()
...
...
@@ -39,8 +37,13 @@ def set_diag(src, values=None, k=0):
row
,
col
,
value
=
src
.
coo
()
mask
=
ext
(
row
.
is_cuda
).
non_diag_mask
(
row
,
col
,
src
.
size
(
0
),
src
.
size
(
1
),
k
)
if
row
.
is_cuda
:
mask
=
torch
.
ops
.
torch_sparse_cuda
.
non_diag_mask
(
row
,
col
,
src
.
size
(
0
),
src
.
size
(
1
),
k
)
else
:
mask
=
torch
.
ops
.
torch_sparse_cpu
.
non_diag_mask
(
row
,
col
,
src
.
size
(
0
),
src
.
size
(
1
),
k
)
inv_mask
=
~
mask
start
,
num_diag
=
-
k
if
k
<
0
else
0
,
mask
.
numel
()
-
row
.
numel
()
...
...
torch_sparse/matmul.py
View file @
f59fe649
import
torch
import
scipy.sparse
from
torch_scatter
import
scatter_add
from
torch_sparse.utils
import
ext
ext
=
None
class
SPMM
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
row
,
rowptr
,
col
,
value
,
mat
,
rowcount
,
colptr
,
csr2csc
,
reduce
):
out
,
arg_out
=
ext
(
mat
.
is_cuda
).
spmm
(
rowptr
,
col
,
value
,
mat
,
reduce
)
if
mat
.
is_cuda
:
out
,
arg_out
=
torch
.
ops
.
torch_sparse_cuda
.
spmm
(
rowptr
,
col
,
value
,
mat
,
reduce
)
else
:
out
,
arg_out
=
torch
.
ops
.
torch_sparse_cpu
.
spmm
(
rowptr
,
col
,
value
,
mat
,
reduce
)
ctx
.
reduce
=
reduce
ctx
.
save_for_backward
(
row
,
rowptr
,
col
,
value
,
mat
,
rowcount
,
colptr
,
...
...
torch_sparse/spspmm.py
View file @
f59fe649
import
torch
from
torch_sparse
import
transpose
,
to_scipy
,
from_scipy
,
coalesce
import
torch_sparse.spspmm_cpu
#
import torch_sparse.spspmm_cpu
if
torch
.
cuda
.
is_available
():
import
torch_sparse.spspmm_cuda
#
if torch.cuda.is_available():
#
import torch_sparse.spspmm_cuda
def
spspmm
(
indexA
,
valueA
,
indexB
,
valueB
,
m
,
k
,
n
,
coalesced
=
False
):
...
...
@@ -25,6 +25,7 @@ def spspmm(indexA, valueA, indexB, valueB, m, k, n, coalesced=False):
:rtype: (:class:`LongTensor`, :class:`Tensor`)
"""
raise
NotImplementedError
if
indexA
.
is_cuda
and
coalesced
:
indexA
,
valueA
=
coalesce
(
indexA
,
valueA
,
m
,
k
)
indexB
,
valueB
=
coalesce
(
indexB
,
valueB
,
k
,
n
)
...
...
torch_sparse/storage.py
View file @
f59fe649
import
warnings
from
typing
import
Optional
,
List
,
Dict
,
Any
import
torch
from
torch_scatter
import
segment_csr
,
scatter_add
from
torch_sparse.utils
import
ext
from
torch_sparse.utils
import
Final
__cache__
=
{
'enabled'
:
True
}
...
...
@@ -32,24 +33,24 @@ class no_cache(object):
return
decorate_no_cache
class
cached_property
(
object
):
def
__init__
(
self
,
func
):
self
.
func
=
func
#
class cached_property(object):
#
def __init__(self, func):
#
self.func = func
def
__get__
(
self
,
obj
,
cls
):
value
=
getattr
(
obj
,
f
'_
{
self
.
func
.
__name__
}
'
,
None
)
if
value
is
None
:
value
=
self
.
func
(
obj
)
if
is_cache_enabled
():
setattr
(
obj
,
f
'_
{
self
.
func
.
__name__
}
'
,
value
)
return
value
#
def __get__(self, obj, cls):
#
value = getattr(obj, f'_{self.func.__name__}', None)
#
if value is None:
#
value = self.func(obj)
#
if is_cache_enabled():
#
setattr(obj, f'_{self.func.__name__}', value)
#
return value
def
optional
(
func
,
src
):
return
func
(
src
)
if
src
is
not
None
else
src
layouts
=
[
'coo'
,
'csr'
,
'csc'
]
layouts
:
Final
[
List
[
str
]]
=
[
'coo'
,
'csr'
,
'csc'
]
def
get_layout
(
layout
=
None
):
...
...
@@ -61,12 +62,30 @@ def get_layout(layout=None):
return
layout
@
torch
.
jit
.
script
class
SparseStorage
(
object
):
cache_keys
=
[
'rowcount'
,
'colptr'
,
'colcount'
,
'csr2csc'
,
'csc2csr'
]
def
__init__
(
self
,
row
=
None
,
rowptr
=
None
,
col
=
None
,
value
=
None
,
sparse_size
=
None
,
rowcount
=
None
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
False
):
_row
:
Optional
[
torch
.
Tensor
]
_rowptr
:
Optional
[
torch
.
Tensor
]
_col
:
torch
.
Tensor
_value
:
Optional
[
torch
.
Tensor
]
_sparse_size
:
List
[
int
]
_rowcount
:
Optional
[
torch
.
Tensor
]
_colptr
:
Optional
[
torch
.
Tensor
]
_colcount
:
Optional
[
torch
.
Tensor
]
_csr2csc
:
Optional
[
torch
.
Tensor
]
_csc2csr
:
Optional
[
torch
.
Tensor
]
def
__init__
(
self
,
row
:
Optional
[
torch
.
Tensor
]
=
None
,
rowptr
:
Optional
[
torch
.
Tensor
]
=
None
,
col
:
Optional
[
torch
.
Tensor
]
=
None
,
value
:
Optional
[
torch
.
Tensor
]
=
None
,
sparse_size
:
Optional
[
List
[
int
]]
=
None
,
rowcount
:
Optional
[
torch
.
Tensor
]
=
None
,
colptr
:
Optional
[
torch
.
Tensor
]
=
None
,
colcount
:
Optional
[
torch
.
Tensor
]
=
None
,
csr2csc
:
Optional
[
torch
.
Tensor
]
=
None
,
csc2csr
:
Optional
[
torch
.
Tensor
]
=
None
,
is_sorted
:
bool
=
False
):
assert
row
is
not
None
or
rowptr
is
not
None
assert
col
is
not
None
...
...
@@ -75,9 +94,16 @@ class SparseStorage(object):
col
=
col
.
contiguous
()
if
sparse_size
is
None
:
M
=
rowptr
.
numel
()
-
1
if
row
is
None
else
row
.
max
().
item
()
+
1
if
rowptr
is
not
None
:
M
=
rowptr
.
numel
()
-
1
elif
row
is
not
None
:
M
=
row
.
max
().
item
()
+
1
else
:
raise
ValueError
N
=
col
.
max
().
item
()
+
1
sparse_size
=
torch
.
Size
([
M
,
N
])
sparse_size
=
torch
.
Size
([
int
(
M
),
int
(
N
)])
else
:
assert
len
(
sparse_size
)
==
2
if
row
is
not
None
:
assert
row
.
dtype
==
torch
.
long
...
...
@@ -145,264 +171,303 @@ class SparseStorage(object):
self
.
_csc2csr
=
csc2csr
if
not
is_sorted
:
idx
=
self
.
col
.
new_zeros
(
col
.
numel
()
+
1
)
idx
[
1
:]
=
sparse_size
[
1
]
*
self
.
row
+
self
.
col
idx
=
col
.
new_zeros
(
col
.
numel
()
+
1
)
idx
[
1
:]
=
sparse_size
[
1
]
*
self
.
row
()
+
col
if
(
idx
[
1
:]
<
idx
[:
-
1
]).
any
():
perm
=
idx
[
1
:].
argsort
()
self
.
_row
=
self
.
row
[
perm
]
self
.
_col
=
self
.
col
[
perm
]
self
.
_value
=
self
.
value
[
perm
]
if
self
.
has_value
()
else
None
self
.
_row
=
self
.
row
()[
perm
]
self
.
_col
=
col
[
perm
]
if
value
is
not
None
:
self
.
_value
=
value
[
perm
]
self
.
_csr2csc
=
None
self
.
_csc2csr
=
None
def
has_row
(
self
):
def
has_row
(
self
)
->
bool
:
return
self
.
_row
is
not
None
@
property
def
row
(
self
):
if
self
.
_row
is
None
:
self
.
_row
=
ext
(
self
.
col
.
is_cuda
).
ptr2ind
(
self
.
rowptr
,
self
.
col
.
numel
())
return
self
.
_row
row
=
self
.
_row
if
row
is
not
None
:
return
row
def
has_rowptr
(
self
):
rowptr
=
self
.
_rowptr
if
rowptr
is
not
None
:
if
rowptr
.
is_cuda
:
row
=
torch
.
ops
.
torch_sparse_cuda
.
ptr2ind
(
rowptr
,
self
.
_col
.
numel
())
else
:
if
rowptr
.
is_cuda
:
row
=
torch
.
ops
.
torch_sparse_cuda
.
ptr2ind
(
rowptr
,
self
.
_col
.
numel
())
else
:
row
=
torch
.
ops
.
torch_sparse_cpu
.
ptr2ind
(
rowptr
,
self
.
_col
.
numel
())
self
.
_row
=
row
return
row
raise
ValueError
def
has_rowptr
(
self
)
->
bool
:
return
self
.
_rowptr
is
not
None
@
property
def
rowptr
(
self
):
if
self
.
_rowptr
is
None
:
self
.
_rowptr
=
ext
(
self
.
col
.
is_cuda
).
ind2ptr
(
self
.
row
,
self
.
sparse_size
[
0
])
return
self
.
_rowptr
def
rowptr
(
self
)
->
torch
.
Tensor
:
rowptr
=
self
.
_rowptr
if
rowptr
is
not
None
:
return
rowptr
@
property
def
col
(
self
):
row
=
self
.
_row
if
row
is
not
None
:
if
row
.
is_cuda
:
rowptr
=
torch
.
ops
.
torch_sparse_cuda
.
ind2ptr
(
row
,
self
.
_sparse_size
[
0
])
else
:
rowptr
=
torch
.
ops
.
torch_sparse_cpu
.
ind2ptr
(
row
,
self
.
_sparse_size
[
0
])
self
.
_rowptr
=
rowptr
return
rowptr
raise
ValueError
def
col
(
self
)
->
torch
.
Tensor
:
return
self
.
_col
def
has_value
(
self
):
def
has_value
(
self
)
->
bool
:
return
self
.
_value
is
not
None
@
property
def
value
(
self
):
def
value
(
self
)
->
Optional
[
torch
.
Tensor
]:
return
self
.
_value
def
set_value_
(
self
,
value
,
layout
=
None
,
dtype
=
None
):
if
isinstance
(
value
,
int
)
or
isinstance
(
value
,
float
):
value
=
torch
.
full
((
self
.
col
.
numel
(),
),
dtype
=
dtype
,
device
=
self
.
col
.
device
)
#
def set_value_(self, value, layout=None, dtype=None):
#
if isinstance(value, int) or isinstance(value, float):
#
value = torch.full((self.col.numel(), ), dtype=dtype,
#
device=self.col.device)
elif
torch
.
is_tensor
(
value
)
and
get_layout
(
layout
)
==
'csc'
:
value
=
value
[
self
.
csc2csr
]
#
elif torch.is_tensor(value) and get_layout(layout) == 'csc':
#
value = value[self.csc2csr]
if
torch
.
is_tensor
(
value
):
value
=
value
if
dtype
is
None
else
value
.
to
(
dtype
)
assert
value
.
device
==
self
.
col
.
device
assert
value
.
size
(
0
)
==
self
.
col
.
numel
()
#
if torch.is_tensor(value):
#
value = value if dtype is None else value.to(dtype)
#
assert value.device == self.col.device
#
assert value.size(0) == self.col.numel()
self
.
_value
=
value
return
self
#
self._value = value
#
return self
def
set_value
(
self
,
value
,
layout
=
None
,
dtype
=
None
):
if
isinstance
(
value
,
int
)
or
isinstance
(
value
,
float
):
value
=
torch
.
full
((
self
.
col
.
numel
(),
),
dtype
=
dtype
,
device
=
self
.
col
.
device
)
#
def set_value(self, value, layout=None, dtype=None):
#
if isinstance(value, int) or isinstance(value, float):
#
value = torch.full((self.col.numel(), ), dtype=dtype,
#
device=self.col.device)
elif
torch
.
is_tensor
(
value
)
and
get_layout
(
layout
)
==
'csc'
:
value
=
value
[
self
.
csc2csr
]
#
elif torch.is_tensor(value) and get_layout(layout) == 'csc':
#
value = value[self.csc2csr]
if
torch
.
is_tensor
(
value
):
value
=
value
if
dtype
is
None
else
value
.
to
(
dtype
)
assert
value
.
device
==
self
.
col
.
device
assert
value
.
size
(
0
)
==
self
.
col
.
numel
()
#
if torch.is_tensor(value):
#
value = value if dtype is None else value.to(dtype)
#
assert value.device == self.col.device
#
assert value.size(0) == self.col.numel()
return
self
.
__class__
(
row
=
self
.
_row
,
rowptr
=
self
.
_rowptr
,
col
=
self
.
col
,
value
=
value
,
sparse_size
=
self
.
_sparse_size
,
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
colcount
=
self
.
_colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
#
return self.__class__(row=self._row, rowptr=self._rowptr, col=self.col,
#
value=value, sparse_size=self._sparse_size,
#
rowcount=self._rowcount, colptr=self._colptr,
#
colcount=self._colcount, csr2csc=self._csr2csc,
#
csc2csr=self._csc2csr, is_sorted=True)
@
property
def
sparse_size
(
self
):
def
sparse_size
(
self
)
->
List
[
int
]:
return
self
.
_sparse_size
def
sparse_resize
(
self
,
*
sizes
):
old_sparse_size
,
nnz
=
self
.
sparse_size
,
self
.
col
.
numel
()
diff_0
=
sizes
[
0
]
-
old_sparse_size
[
0
]
rowcount
,
rowptr
=
self
.
_rowcount
,
self
.
_rowptr
if
diff_0
>
0
:
if
rowptr
is
not
None
:
rowptr
=
torch
.
cat
([
rowptr
,
rowptr
.
new_full
((
diff_0
,
),
nnz
)])
if
rowcount
is
not
None
:
rowcount
=
torch
.
cat
([
rowcount
,
rowcount
.
new_zeros
(
diff_0
)])
else
:
if
rowptr
is
not
None
:
rowptr
=
rowptr
[:
-
diff_0
]
if
rowcount
is
not
None
:
rowcount
=
rowcount
[:
-
diff_0
]
diff_1
=
sizes
[
1
]
-
old_sparse_size
[
1
]
colcount
,
colptr
=
self
.
_colcount
,
self
.
_colptr
if
diff_1
>
0
:
if
colptr
is
not
None
:
colptr
=
torch
.
cat
([
colptr
,
colptr
.
new_full
((
diff_1
,
),
nnz
)])
if
colcount
is
not
None
:
colcount
=
torch
.
cat
([
colcount
,
colcount
.
new_zeros
(
diff_1
)])
else
:
if
colptr
is
not
None
:
colptr
=
colptr
[:
-
diff_1
]
if
colcount
is
not
None
:
colcount
=
colcount
[:
-
diff_1
]
return
self
.
__class__
(
row
=
self
.
_row
,
rowptr
=
rowptr
,
col
=
self
.
col
,
value
=
self
.
value
,
sparse_size
=
sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
def
has_rowcount
(
self
):
#
def sparse_resize(self, *sizes):
#
old_sparse_size, nnz = self.sparse_size, self.col.numel()
#
diff_0 = sizes[0] - old_sparse_size[0]
#
rowcount, rowptr = self._rowcount, self._rowptr
#
if diff_0 > 0:
#
if rowptr is not None:
#
rowptr = torch.cat([rowptr, rowptr.new_full((diff_0, ), nnz)])
#
if rowcount is not None:
#
rowcount = torch.cat([rowcount, rowcount.new_zeros(diff_0)])
#
else:
#
if rowptr is not None:
#
rowptr = rowptr[:-diff_0]
#
if rowcount is not None:
#
rowcount = rowcount[:-diff_0]
#
diff_1 = sizes[1] - old_sparse_size[1]
#
colcount, colptr = self._colcount, self._colptr
#
if diff_1 > 0:
#
if colptr is not None:
#
colptr = torch.cat([colptr, colptr.new_full((diff_1, ), nnz)])
#
if colcount is not None:
#
colcount = torch.cat([colcount, colcount.new_zeros(diff_1)])
#
else:
#
if colptr is not None:
#
colptr = colptr[:-diff_1]
#
if colcount is not None:
#
colcount = colcount[:-diff_1]
#
return self.__class__(row=self._row, rowptr=rowptr, col=self.col,
#
value=self.value, sparse_size=sizes,
#
rowcount=rowcount, colptr=colptr,
#
colcount=colcount, csr2csc=self._csr2csc,
#
csc2csr=self._csc2csr, is_sorted=True)
def
has_rowcount
(
self
)
->
bool
:
return
self
.
_rowcount
is
not
None
@
cached_property
def
rowcount
(
self
):
return
self
.
rowptr
[
1
:]
-
self
.
rowptr
[:
-
1
]
def
rowcount
(
self
)
->
torch
.
Tensor
:
rowcount
=
self
.
_rowcount
if
rowcount
is
not
None
:
return
rowcount
rowptr
=
self
.
rowptr
()
rowcount
=
rowptr
[
1
:]
-
rowptr
[
1
:]
self
.
_rowcount
=
rowcount
return
rowcount
def
has_colptr
(
self
):
def
has_colptr
(
self
)
->
bool
:
return
self
.
_colptr
is
not
None
@
cached_property
def
colptr
(
self
):
if
self
.
has_csr2csc
():
return
ext
(
self
.
col
.
is_cuda
).
ind2ptr
(
self
.
col
[
self
.
csr2csc
],
self
.
sparse_size
[
1
])
else
:
colptr
=
self
.
col
.
new_zeros
(
self
.
sparse_size
[
1
]
+
1
)
torch
.
cumsum
(
self
.
colcount
,
dim
=
0
,
out
=
colptr
[
1
:])
def
colptr
(
self
)
->
torch
.
Tensor
:
colptr
=
self
.
_colptr
if
colptr
is
not
None
:
return
colptr
def
has_colcount
(
self
):
csr2csc
=
self
.
_csr2csc
if
csr2csc
is
not
None
:
colptr
=
torch
.
ops
.
torch_sparse_cpu
.
ind2ptr
(
self
.
_col
[
csr2csc
],
self
.
_sparse_size
[
1
])
else
:
colptr
=
self
.
_col
.
new_zeros
(
self
.
_sparse_size
[
1
]
+
1
)
torch
.
cumsum
(
self
.
colcount
(),
dim
=
0
,
out
=
colptr
[
1
:])
self
.
_colptr
=
colptr
return
colptr
def
has_colcount
(
self
)
->
bool
:
return
self
.
_colcount
is
not
None
@
cached_property
def
colcount
(
self
):
if
self
.
has_colptr
():
return
self
.
colptr
[
1
:]
-
self
.
colptr
[:
-
1
]
def
colcount
(
self
)
->
torch
.
Tensor
:
colcount
=
self
.
_colcount
if
colcount
is
not
None
:
return
colcount
colptr
=
self
.
_colptr
if
colptr
is
not
None
:
colcount
=
colptr
[
1
:]
-
colptr
[
1
:]
else
:
return
scatter_add
(
torch
.
ones_like
(
self
.
col
),
self
.
col
,
dim_size
=
self
.
sparse_size
[
1
])
raise
NotImplementedError
# colcount = scatter_add(torch.ones_like(self._col), self._col,
# dim_size=self._sparse_size[1])
self
.
_colcount
=
colcount
return
colcount
def
has_csr2csc
(
self
):
def
has_csr2csc
(
self
)
->
bool
:
return
self
.
_csr2csc
is
not
None
@
cached_property
def
csr2csc
(
self
):
i
dx
=
self
.
sparse_size
[
0
]
*
self
.
col
+
self
.
row
return
idx
.
argsort
()
def
csr2csc
(
self
)
->
torch
.
Tensor
:
csr2csc
=
self
.
_csr2csc
i
f
csr2csc
is
not
None
:
return
csr2csc
def
has_csc2csr
(
self
):
idx
=
self
.
_sparse_size
[
0
]
*
self
.
_col
+
self
.
row
()
csr2csc
=
idx
.
argsort
()
self
.
_csr2csc
=
csr2csc
return
csr2csc
def
has_csc2csr
(
self
)
->
bool
:
return
self
.
_csc2csr
is
not
None
@
cached_property
def
csc2csr
(
self
):
return
self
.
csr2csc
.
argsort
()
def
csc2csr
(
self
)
->
torch
.
Tensor
:
csc2csr
=
self
.
_csc2csr
if
csc2csr
is
not
None
:
return
csc2csr
def
is_coalesced
(
self
):
idx
=
self
.
col
.
new_full
((
self
.
col
.
numel
()
+
1
,
),
-
1
)
idx
[
1
:]
=
self
.
sparse_size
[
1
]
*
self
.
row
+
self
.
col
return
(
idx
[
1
:]
>
idx
[:
-
1
]).
all
().
item
()
csc2csr
=
self
.
csr2csc
().
argsort
()
self
.
_csc2csr
=
csc2csr
return
csc2csr
def
coalesce
(
self
,
reduce
=
'add'
):
idx
=
self
.
col
.
new_full
((
self
.
col
.
numel
()
+
1
,
),
-
1
)
idx
[
1
:]
=
self
.
sparse_size
[
1
]
*
self
.
row
+
self
.
col
def
is_coalesced
(
self
)
->
bool
:
idx
=
self
.
_col
.
new_full
((
self
.
_col
.
numel
()
+
1
,
),
-
1
)
idx
[
1
:]
=
self
.
_sparse_size
[
1
]
*
self
.
row
()
+
self
.
_col
return
bool
((
idx
[
1
:]
>
idx
[:
-
1
]).
all
())
def
coalesce
(
self
,
reduce
:
str
=
"add"
):
idx
=
self
.
_col
.
new_full
((
self
.
_col
.
numel
()
+
1
,
),
-
1
)
idx
[
1
:]
=
self
.
_sparse_size
[
1
]
*
self
.
row
()
+
self
.
_col
mask
=
idx
[
1
:]
>
idx
[:
-
1
]
if
mask
.
all
():
# Skip if indices are already coalesced.
return
self
row
=
self
.
row
[
mask
]
col
=
self
.
col
[
mask
]
row
=
self
.
row
()
[
mask
]
col
=
self
.
_
col
[
mask
]
value
=
self
.
value
if
self
.
has_value
()
:
value
=
self
.
_
value
if
value
is
not
None
:
ptr
=
mask
.
nonzero
().
flatten
()
ptr
=
torch
.
cat
([
ptr
,
ptr
.
new_full
((
1
,
),
value
.
size
(
0
))])
value
=
segment_csr
(
value
,
ptr
,
reduce
=
reduce
)
raise
NotImplementedError
# value = segment_csr(value, ptr, reduce=reduce)
value
=
value
[
0
]
if
isinstance
(
value
,
tuple
)
else
value
return
self
.
__class__
(
row
=
row
,
col
=
col
,
value
=
value
,
sparse_size
=
self
.
sparse_size
,
is_sorted
=
True
)
def
cached_keys
(
self
):
return
[
key
for
key
in
self
.
cache_keys
if
getattr
(
self
,
f
'_
{
key
}
'
,
None
)
is
not
None
]
def
fill_cache_
(
self
,
*
args
):
for
arg
in
args
or
self
.
cache_keys
+
[
'row'
,
'rowptr'
]:
getattr
(
self
,
arg
)
return
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_size
=
self
.
_sparse_size
,
rowcount
=
None
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
def
fill_cache_
(
self
):
self
.
row
()
self
.
rowptr
()
self
.
rowcount
()
self
.
colptr
()
self
.
colcount
()
self
.
csr2csc
()
self
.
csc2csr
()
return
self
def
clear_cache_
(
self
,
*
args
):
for
arg
in
args
or
self
.
cache_keys
:
setattr
(
self
,
f
'_
{
arg
}
'
,
None
)
def
clear_cache_
(
self
):
self
.
_rowcount
=
None
self
.
_colptr
=
None
self
.
_colcount
=
None
self
.
_csr2csc
=
None
self
.
_csc2csr
=
None
return
self
def
__copy__
(
self
):
return
self
.
apply
(
lambda
x
:
x
)
return
SparseStorage
(
row
=
self
.
_row
,
rowptr
=
self
.
_rowptr
,
col
=
self
.
_col
,
value
=
self
.
_value
,
sparse_size
=
self
.
_sparse_size
,
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
colcount
=
self
.
_colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
def
clone
(
self
):
return
self
.
apply
(
lambda
x
:
x
.
clone
())
def
__deepcopy__
(
self
,
memo
):
new_storage
=
self
.
clone
()
memo
[
id
(
self
)]
=
new_storage
return
new_storage
def
apply_value_
(
self
,
func
):
self
.
_value
=
optional
(
func
,
self
.
value
)
return
self
def
apply_value
(
self
,
func
):
return
self
.
__class__
(
row
=
self
.
_row
,
rowptr
=
self
.
_rowptr
,
col
=
self
.
col
,
value
=
optional
(
func
,
self
.
value
),
sparse_size
=
self
.
sparse_size
,
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
colcount
=
self
.
_colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
def
apply_
(
self
,
func
):
self
.
_row
=
optional
(
func
,
self
.
_row
)
self
.
_rowptr
=
optional
(
func
,
self
.
_rowptr
)
self
.
_col
=
func
(
self
.
col
)
self
.
_value
=
optional
(
func
,
self
.
value
)
for
key
in
self
.
cached_keys
():
setattr
(
self
,
f
'_
{
key
}
'
,
func
(
getattr
(
self
,
f
'_
{
key
}
'
)))
return
self
def
apply
(
self
,
func
):
return
self
.
__class__
(
row
=
optional
(
func
,
self
.
_row
),
rowptr
=
optional
(
func
,
self
.
_rowptr
),
col
=
func
(
self
.
col
),
value
=
optional
(
func
,
self
.
value
),
sparse_size
=
self
.
sparse_size
,
rowcount
=
optional
(
func
,
self
.
_rowcount
),
colptr
=
optional
(
func
,
self
.
_colptr
),
colcount
=
optional
(
func
,
self
.
_colcount
),
csr2csc
=
optional
(
func
,
self
.
_csr2csc
),
csc2csr
=
optional
(
func
,
self
.
_csc2csr
),
is_sorted
=
True
,
)
def
map
(
self
,
func
):
data
=
[]
if
self
.
has_row
():
data
+=
[
func
(
self
.
row
)]
if
self
.
has_rowptr
():
data
+=
[
func
(
self
.
rowptr
)]
data
+=
[
func
(
self
.
col
)]
if
self
.
has_value
():
data
+=
[
func
(
self
.
value
)]
data
+=
[
func
(
getattr
(
self
,
f
'_
{
key
}
'
))
for
key
in
self
.
cached_keys
()]
return
data
row
=
self
.
_row
if
row
is
not
None
:
row
=
row
.
clone
()
rowptr
=
self
.
_rowptr
if
rowptr
is
not
None
:
rowptr
=
rowptr
.
clone
()
value
=
self
.
_value
if
value
is
not
None
:
value
=
value
.
clone
()
rowcount
=
self
.
_rowcount
if
rowcount
is
not
None
:
rowcount
=
rowcount
.
clone
()
colptr
=
self
.
_colptr
if
colptr
is
not
None
:
colptr
=
colptr
.
clone
()
colcount
=
self
.
_colcount
if
colcount
is
not
None
:
colcount
=
colcount
.
clone
()
csr2csc
=
self
.
_csr2csc
if
csr2csc
is
not
None
:
csr2csc
=
csr2csc
.
clone
()
csc2csr
=
self
.
_csc2csr
if
csc2csr
is
not
None
:
csc2csr
=
csc2csr
.
clone
()
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
self
.
_col
.
clone
(),
value
=
value
,
sparse_size
=
self
.
_sparse_size
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
def
__deepcopy__
(
self
,
memo
:
Dict
[
str
,
Any
]):
return
self
.
clone
()
torch_sparse/utils.py
View file @
f59fe649
from
typing
import
Any
import
torch
try
:
from
typing_extensions
import
Final
# noqa
except
ImportError
:
from
torch.jit
import
Final
# noqa
torch
.
ops
.
load_library
(
'torch_sparse/convert_cpu.so'
)
torch
.
ops
.
load_library
(
'torch_sparse/diag_cpu.so'
)
torch
.
ops
.
load_library
(
'torch_sparse/spmm_cpu.so'
)
...
...
@@ -14,10 +21,5 @@ except OSError as e:
raise
e
def
ext
(
is_cuda
):
name
=
'torch_sparse_cuda'
if
is_cuda
else
'torch_sparse_cpu'
return
getattr
(
torch
.
ops
,
name
)
def
is_scalar
(
other
):
def
is_scalar
(
other
:
Any
)
->
bool
:
return
isinstance
(
other
,
int
)
or
isinstance
(
other
,
float
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment