Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
f59fe649
"tests/vscode:/vscode.git/clone" did not exist on "86c243b45f0d1652a476d9c5ac165f22bf95c91e"
Commit
f59fe649
authored
Jan 27, 2020
by
rusty1s
Browse files
beginning of torch script support
parent
c4484dbb
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
391 additions
and
233 deletions
+391
-233
test/test_jit.py
test/test_jit.py
+80
-0
torch_sparse/coalesce.py
torch_sparse/coalesce.py
+2
-1
torch_sparse/diag.py
torch_sparse/diag.py
+7
-4
torch_sparse/matmul.py
torch_sparse/matmul.py
+8
-2
torch_sparse/spspmm.py
torch_sparse/spspmm.py
+4
-3
torch_sparse/storage.py
torch_sparse/storage.py
+282
-217
torch_sparse/utils.py
torch_sparse/utils.py
+8
-6
No files found.
test/test_jit.py
0 → 100644
View file @
f59fe649
import
torch
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.storage
import
SparseStorage
from
typing
import
Dict
,
Any
# class MyTensor(dict):
# def __init__(self, rowptr, col):
# self['rowptr'] = rowptr
# self['col'] = col
# def rowptr(self: Dict[str, torch.Tensor]):
# return self['rowptr']
@
torch
.
jit
.
script
class
Foo
:
rowptr
:
torch
.
Tensor
col
:
torch
.
Tensor
def
__init__
(
self
,
rowptr
:
torch
.
Tensor
,
col
:
torch
.
Tensor
):
self
.
rowptr
=
rowptr
self
.
col
=
col
class
MyCell
(
torch
.
nn
.
Module
):
def
__init__
(
self
):
super
(
MyCell
,
self
).
__init__
()
self
.
linear
=
torch
.
nn
.
Linear
(
2
,
4
)
# def forward(self, x: torch.Tensor, ptr: torch.Tensor) -> torch.Tensor:
def
forward
(
self
,
x
:
torch
.
Tensor
,
adj
:
SparseStorage
)
->
torch
.
Tensor
:
out
,
_
=
torch
.
ops
.
torch_sparse_cpu
.
spmm
(
adj
.
rowptr
(),
adj
.
col
(),
None
,
x
,
'sum'
)
return
out
# ind = torch.ops.torch_sparse_cpu.ptr2ind(ptr, ptr[-1].item())
# # ind = ptr2ind(ptr, E)
# x_j = x[ind]
# out = self.linear(x_j)
# return out
def
test_jit
():
my_cell
=
MyCell
()
# x = torch.rand(3, 2)
# ptr = torch.tensor([0, 2, 4, 6])
# out = my_cell(x, ptr)
# print()
# print(out)
# traced_cell = torch.jit.trace(my_cell, (x, ptr))
# print(traced_cell)
# out = traced_cell(x, ptr)
# print(out)
x
=
torch
.
randn
(
3
,
2
)
# adj = torch.randn(3, 3)
# adj = SparseTensor.from_dense(adj)
# adj = Foo(adj.storage.rowptr, adj.storage.col)
# adj = adj.storage
rowptr
=
torch
.
tensor
([
0
,
3
,
6
,
9
])
col
=
torch
.
tensor
([
0
,
1
,
2
,
0
,
1
,
2
,
0
,
1
,
2
])
adj
=
SparseStorage
(
rowptr
=
rowptr
,
col
=
col
)
# adj = {'rowptr': mat.storage.rowptr, 'col': mat.storage.col}
# foo = Foo(mat.storage.rowptr, mat.storage.col)
# adj = MyTensor(mat.storage.rowptr, mat.storage.col)
traced_cell
=
torch
.
jit
.
script
(
my_cell
)
print
(
traced_cell
)
out
=
traced_cell
(
x
,
adj
)
print
(
out
)
# # print(traced_cell.code)
torch_sparse/coalesce.py
View file @
f59fe649
import
torch
import
torch
import
torch_scatter
import
torch_scatter
from
.unique
import
unique
#
from .unique import unique
def
coalesce
(
index
,
value
,
m
,
n
,
op
=
'add'
,
fill_value
=
0
):
def
coalesce
(
index
,
value
,
m
,
n
,
op
=
'add'
,
fill_value
=
0
):
...
@@ -22,6 +22,7 @@ def coalesce(index, value, m, n, op='add', fill_value=0):
...
@@ -22,6 +22,7 @@ def coalesce(index, value, m, n, op='add', fill_value=0):
:rtype: (:class:`LongTensor`, :class:`Tensor`)
:rtype: (:class:`LongTensor`, :class:`Tensor`)
"""
"""
raise
NotImplementedError
row
,
col
=
index
row
,
col
=
index
...
...
torch_sparse/diag.py
View file @
f59fe649
import
torch
import
torch
from
torch_sparse.utils
import
ext
def
remove_diag
(
src
,
k
=
0
):
def
remove_diag
(
src
,
k
=
0
):
row
,
col
,
value
=
src
.
coo
()
row
,
col
,
value
=
src
.
coo
()
...
@@ -39,8 +37,13 @@ def set_diag(src, values=None, k=0):
...
@@ -39,8 +37,13 @@ def set_diag(src, values=None, k=0):
row
,
col
,
value
=
src
.
coo
()
row
,
col
,
value
=
src
.
coo
()
mask
=
ext
(
row
.
is_cuda
).
non_diag_mask
(
row
,
col
,
src
.
size
(
0
),
src
.
size
(
1
),
if
row
.
is_cuda
:
k
)
mask
=
torch
.
ops
.
torch_sparse_cuda
.
non_diag_mask
(
row
,
col
,
src
.
size
(
0
),
src
.
size
(
1
),
k
)
else
:
mask
=
torch
.
ops
.
torch_sparse_cpu
.
non_diag_mask
(
row
,
col
,
src
.
size
(
0
),
src
.
size
(
1
),
k
)
inv_mask
=
~
mask
inv_mask
=
~
mask
start
,
num_diag
=
-
k
if
k
<
0
else
0
,
mask
.
numel
()
-
row
.
numel
()
start
,
num_diag
=
-
k
if
k
<
0
else
0
,
mask
.
numel
()
-
row
.
numel
()
...
...
torch_sparse/matmul.py
View file @
f59fe649
import
torch
import
torch
import
scipy.sparse
import
scipy.sparse
from
torch_scatter
import
scatter_add
from
torch_scatter
import
scatter_add
from
torch_sparse.utils
import
ext
ext
=
None
class
SPMM
(
torch
.
autograd
.
Function
):
class
SPMM
(
torch
.
autograd
.
Function
):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
row
,
rowptr
,
col
,
value
,
mat
,
rowcount
,
colptr
,
csr2csc
,
def
forward
(
ctx
,
row
,
rowptr
,
col
,
value
,
mat
,
rowcount
,
colptr
,
csr2csc
,
reduce
):
reduce
):
out
,
arg_out
=
ext
(
mat
.
is_cuda
).
spmm
(
rowptr
,
col
,
value
,
mat
,
reduce
)
if
mat
.
is_cuda
:
out
,
arg_out
=
torch
.
ops
.
torch_sparse_cuda
.
spmm
(
rowptr
,
col
,
value
,
mat
,
reduce
)
else
:
out
,
arg_out
=
torch
.
ops
.
torch_sparse_cpu
.
spmm
(
rowptr
,
col
,
value
,
mat
,
reduce
)
ctx
.
reduce
=
reduce
ctx
.
reduce
=
reduce
ctx
.
save_for_backward
(
row
,
rowptr
,
col
,
value
,
mat
,
rowcount
,
colptr
,
ctx
.
save_for_backward
(
row
,
rowptr
,
col
,
value
,
mat
,
rowcount
,
colptr
,
...
...
torch_sparse/spspmm.py
View file @
f59fe649
import
torch
import
torch
from
torch_sparse
import
transpose
,
to_scipy
,
from_scipy
,
coalesce
from
torch_sparse
import
transpose
,
to_scipy
,
from_scipy
,
coalesce
import
torch_sparse.spspmm_cpu
#
import torch_sparse.spspmm_cpu
if
torch
.
cuda
.
is_available
():
#
if torch.cuda.is_available():
import
torch_sparse.spspmm_cuda
#
import torch_sparse.spspmm_cuda
def
spspmm
(
indexA
,
valueA
,
indexB
,
valueB
,
m
,
k
,
n
,
coalesced
=
False
):
def
spspmm
(
indexA
,
valueA
,
indexB
,
valueB
,
m
,
k
,
n
,
coalesced
=
False
):
...
@@ -25,6 +25,7 @@ def spspmm(indexA, valueA, indexB, valueB, m, k, n, coalesced=False):
...
@@ -25,6 +25,7 @@ def spspmm(indexA, valueA, indexB, valueB, m, k, n, coalesced=False):
:rtype: (:class:`LongTensor`, :class:`Tensor`)
:rtype: (:class:`LongTensor`, :class:`Tensor`)
"""
"""
raise
NotImplementedError
if
indexA
.
is_cuda
and
coalesced
:
if
indexA
.
is_cuda
and
coalesced
:
indexA
,
valueA
=
coalesce
(
indexA
,
valueA
,
m
,
k
)
indexA
,
valueA
=
coalesce
(
indexA
,
valueA
,
m
,
k
)
indexB
,
valueB
=
coalesce
(
indexB
,
valueB
,
k
,
n
)
indexB
,
valueB
=
coalesce
(
indexB
,
valueB
,
k
,
n
)
...
...
torch_sparse/storage.py
View file @
f59fe649
import
warnings
import
warnings
from
typing
import
Optional
,
List
,
Dict
,
Any
import
torch
import
torch
from
torch_scatter
import
segment_csr
,
scatter_add
from
torch_scatter
import
segment_csr
,
scatter_add
from
torch_sparse.utils
import
ext
from
torch_sparse.utils
import
Final
__cache__
=
{
'enabled'
:
True
}
__cache__
=
{
'enabled'
:
True
}
...
@@ -32,24 +33,24 @@ class no_cache(object):
...
@@ -32,24 +33,24 @@ class no_cache(object):
return
decorate_no_cache
return
decorate_no_cache
class
cached_property
(
object
):
#
class cached_property(object):
def
__init__
(
self
,
func
):
#
def __init__(self, func):
self
.
func
=
func
#
self.func = func
def
__get__
(
self
,
obj
,
cls
):
#
def __get__(self, obj, cls):
value
=
getattr
(
obj
,
f
'_
{
self
.
func
.
__name__
}
'
,
None
)
#
value = getattr(obj, f'_{self.func.__name__}', None)
if
value
is
None
:
#
if value is None:
value
=
self
.
func
(
obj
)
#
value = self.func(obj)
if
is_cache_enabled
():
#
if is_cache_enabled():
setattr
(
obj
,
f
'_
{
self
.
func
.
__name__
}
'
,
value
)
#
setattr(obj, f'_{self.func.__name__}', value)
return
value
#
return value
def
optional
(
func
,
src
):
def
optional
(
func
,
src
):
return
func
(
src
)
if
src
is
not
None
else
src
return
func
(
src
)
if
src
is
not
None
else
src
layouts
=
[
'coo'
,
'csr'
,
'csc'
]
layouts
:
Final
[
List
[
str
]]
=
[
'coo'
,
'csr'
,
'csc'
]
def
get_layout
(
layout
=
None
):
def
get_layout
(
layout
=
None
):
...
@@ -61,12 +62,30 @@ def get_layout(layout=None):
...
@@ -61,12 +62,30 @@ def get_layout(layout=None):
return
layout
return
layout
@
torch
.
jit
.
script
class
SparseStorage
(
object
):
class
SparseStorage
(
object
):
cache_keys
=
[
'rowcount'
,
'colptr'
,
'colcount'
,
'csr2csc'
,
'csc2csr'
]
_row
:
Optional
[
torch
.
Tensor
]
_rowptr
:
Optional
[
torch
.
Tensor
]
def
__init__
(
self
,
row
=
None
,
rowptr
=
None
,
col
=
None
,
value
=
None
,
_col
:
torch
.
Tensor
sparse_size
=
None
,
rowcount
=
None
,
colptr
=
None
,
colcount
=
None
,
_value
:
Optional
[
torch
.
Tensor
]
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
False
):
_sparse_size
:
List
[
int
]
_rowcount
:
Optional
[
torch
.
Tensor
]
_colptr
:
Optional
[
torch
.
Tensor
]
_colcount
:
Optional
[
torch
.
Tensor
]
_csr2csc
:
Optional
[
torch
.
Tensor
]
_csc2csr
:
Optional
[
torch
.
Tensor
]
def
__init__
(
self
,
row
:
Optional
[
torch
.
Tensor
]
=
None
,
rowptr
:
Optional
[
torch
.
Tensor
]
=
None
,
col
:
Optional
[
torch
.
Tensor
]
=
None
,
value
:
Optional
[
torch
.
Tensor
]
=
None
,
sparse_size
:
Optional
[
List
[
int
]]
=
None
,
rowcount
:
Optional
[
torch
.
Tensor
]
=
None
,
colptr
:
Optional
[
torch
.
Tensor
]
=
None
,
colcount
:
Optional
[
torch
.
Tensor
]
=
None
,
csr2csc
:
Optional
[
torch
.
Tensor
]
=
None
,
csc2csr
:
Optional
[
torch
.
Tensor
]
=
None
,
is_sorted
:
bool
=
False
):
assert
row
is
not
None
or
rowptr
is
not
None
assert
row
is
not
None
or
rowptr
is
not
None
assert
col
is
not
None
assert
col
is
not
None
...
@@ -75,9 +94,16 @@ class SparseStorage(object):
...
@@ -75,9 +94,16 @@ class SparseStorage(object):
col
=
col
.
contiguous
()
col
=
col
.
contiguous
()
if
sparse_size
is
None
:
if
sparse_size
is
None
:
M
=
rowptr
.
numel
()
-
1
if
row
is
None
else
row
.
max
().
item
()
+
1
if
rowptr
is
not
None
:
M
=
rowptr
.
numel
()
-
1
elif
row
is
not
None
:
M
=
row
.
max
().
item
()
+
1
else
:
raise
ValueError
N
=
col
.
max
().
item
()
+
1
N
=
col
.
max
().
item
()
+
1
sparse_size
=
torch
.
Size
([
M
,
N
])
sparse_size
=
torch
.
Size
([
int
(
M
),
int
(
N
)])
else
:
assert
len
(
sparse_size
)
==
2
if
row
is
not
None
:
if
row
is
not
None
:
assert
row
.
dtype
==
torch
.
long
assert
row
.
dtype
==
torch
.
long
...
@@ -145,264 +171,303 @@ class SparseStorage(object):
...
@@ -145,264 +171,303 @@ class SparseStorage(object):
self
.
_csc2csr
=
csc2csr
self
.
_csc2csr
=
csc2csr
if
not
is_sorted
:
if
not
is_sorted
:
idx
=
self
.
col
.
new_zeros
(
col
.
numel
()
+
1
)
idx
=
col
.
new_zeros
(
col
.
numel
()
+
1
)
idx
[
1
:]
=
sparse_size
[
1
]
*
self
.
row
+
self
.
col
idx
[
1
:]
=
sparse_size
[
1
]
*
self
.
row
()
+
col
if
(
idx
[
1
:]
<
idx
[:
-
1
]).
any
():
if
(
idx
[
1
:]
<
idx
[:
-
1
]).
any
():
perm
=
idx
[
1
:].
argsort
()
perm
=
idx
[
1
:].
argsort
()
self
.
_row
=
self
.
row
[
perm
]
self
.
_row
=
self
.
row
()[
perm
]
self
.
_col
=
self
.
col
[
perm
]
self
.
_col
=
col
[
perm
]
self
.
_value
=
self
.
value
[
perm
]
if
self
.
has_value
()
else
None
if
value
is
not
None
:
self
.
_value
=
value
[
perm
]
self
.
_csr2csc
=
None
self
.
_csr2csc
=
None
self
.
_csc2csr
=
None
self
.
_csc2csr
=
None
def
has_row
(
self
):
def
has_row
(
self
)
->
bool
:
return
self
.
_row
is
not
None
return
self
.
_row
is
not
None
@
property
def
row
(
self
):
def
row
(
self
):
if
self
.
_row
is
None
:
row
=
self
.
_row
self
.
_row
=
ext
(
self
.
col
.
is_cuda
).
ptr2ind
(
self
.
rowptr
,
if
row
is
not
None
:
self
.
col
.
numel
())
return
row
return
self
.
_row
def
has_rowptr
(
self
):
rowptr
=
self
.
_rowptr
if
rowptr
is
not
None
:
if
rowptr
.
is_cuda
:
row
=
torch
.
ops
.
torch_sparse_cuda
.
ptr2ind
(
rowptr
,
self
.
_col
.
numel
())
else
:
if
rowptr
.
is_cuda
:
row
=
torch
.
ops
.
torch_sparse_cuda
.
ptr2ind
(
rowptr
,
self
.
_col
.
numel
())
else
:
row
=
torch
.
ops
.
torch_sparse_cpu
.
ptr2ind
(
rowptr
,
self
.
_col
.
numel
())
self
.
_row
=
row
return
row
raise
ValueError
def
has_rowptr
(
self
)
->
bool
:
return
self
.
_rowptr
is
not
None
return
self
.
_rowptr
is
not
None
@
property
def
rowptr
(
self
)
->
torch
.
Tensor
:
def
rowptr
(
self
):
rowptr
=
self
.
_rowptr
if
self
.
_rowptr
is
None
:
if
rowptr
is
not
None
:
self
.
_rowptr
=
ext
(
self
.
col
.
is_cuda
).
ind2ptr
(
return
rowptr
self
.
row
,
self
.
sparse_size
[
0
])
return
self
.
_rowptr
@
property
row
=
self
.
_row
def
col
(
self
):
if
row
is
not
None
:
if
row
.
is_cuda
:
rowptr
=
torch
.
ops
.
torch_sparse_cuda
.
ind2ptr
(
row
,
self
.
_sparse_size
[
0
])
else
:
rowptr
=
torch
.
ops
.
torch_sparse_cpu
.
ind2ptr
(
row
,
self
.
_sparse_size
[
0
])
self
.
_rowptr
=
rowptr
return
rowptr
raise
ValueError
def
col
(
self
)
->
torch
.
Tensor
:
return
self
.
_col
return
self
.
_col
def
has_value
(
self
):
def
has_value
(
self
)
->
bool
:
return
self
.
_value
is
not
None
return
self
.
_value
is
not
None
@
property
def
value
(
self
)
->
Optional
[
torch
.
Tensor
]:
def
value
(
self
):
return
self
.
_value
return
self
.
_value
def
set_value_
(
self
,
value
,
layout
=
None
,
dtype
=
None
):
#
def set_value_(self, value, layout=None, dtype=None):
if
isinstance
(
value
,
int
)
or
isinstance
(
value
,
float
):
#
if isinstance(value, int) or isinstance(value, float):
value
=
torch
.
full
((
self
.
col
.
numel
(),
),
dtype
=
dtype
,
#
value = torch.full((self.col.numel(), ), dtype=dtype,
device
=
self
.
col
.
device
)
#
device=self.col.device)
elif
torch
.
is_tensor
(
value
)
and
get_layout
(
layout
)
==
'csc'
:
#
elif torch.is_tensor(value) and get_layout(layout) == 'csc':
value
=
value
[
self
.
csc2csr
]
#
value = value[self.csc2csr]
if
torch
.
is_tensor
(
value
):
#
if torch.is_tensor(value):
value
=
value
if
dtype
is
None
else
value
.
to
(
dtype
)
#
value = value if dtype is None else value.to(dtype)
assert
value
.
device
==
self
.
col
.
device
#
assert value.device == self.col.device
assert
value
.
size
(
0
)
==
self
.
col
.
numel
()
#
assert value.size(0) == self.col.numel()
self
.
_value
=
value
#
self._value = value
return
self
#
return self
def
set_value
(
self
,
value
,
layout
=
None
,
dtype
=
None
):
#
def set_value(self, value, layout=None, dtype=None):
if
isinstance
(
value
,
int
)
or
isinstance
(
value
,
float
):
#
if isinstance(value, int) or isinstance(value, float):
value
=
torch
.
full
((
self
.
col
.
numel
(),
),
dtype
=
dtype
,
#
value = torch.full((self.col.numel(), ), dtype=dtype,
device
=
self
.
col
.
device
)
#
device=self.col.device)
elif
torch
.
is_tensor
(
value
)
and
get_layout
(
layout
)
==
'csc'
:
#
elif torch.is_tensor(value) and get_layout(layout) == 'csc':
value
=
value
[
self
.
csc2csr
]
#
value = value[self.csc2csr]
if
torch
.
is_tensor
(
value
):
#
if torch.is_tensor(value):
value
=
value
if
dtype
is
None
else
value
.
to
(
dtype
)
#
value = value if dtype is None else value.to(dtype)
assert
value
.
device
==
self
.
col
.
device
#
assert value.device == self.col.device
assert
value
.
size
(
0
)
==
self
.
col
.
numel
()
#
assert value.size(0) == self.col.numel()
return
self
.
__class__
(
row
=
self
.
_row
,
rowptr
=
self
.
_rowptr
,
col
=
self
.
col
,
#
return self.__class__(row=self._row, rowptr=self._rowptr, col=self.col,
value
=
value
,
sparse_size
=
self
.
_sparse_size
,
#
value=value, sparse_size=self._sparse_size,
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
#
rowcount=self._rowcount, colptr=self._colptr,
colcount
=
self
.
_colcount
,
csr2csc
=
self
.
_csr2csc
,
#
colcount=self._colcount, csr2csc=self._csr2csc,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
#
csc2csr=self._csc2csr, is_sorted=True)
@
property
def
sparse_size
(
self
)
->
List
[
int
]:
def
sparse_size
(
self
):
return
self
.
_sparse_size
return
self
.
_sparse_size
def
sparse_resize
(
self
,
*
sizes
):
#
def sparse_resize(self, *sizes):
old_sparse_size
,
nnz
=
self
.
sparse_size
,
self
.
col
.
numel
()
#
old_sparse_size, nnz = self.sparse_size, self.col.numel()
diff_0
=
sizes
[
0
]
-
old_sparse_size
[
0
]
#
diff_0 = sizes[0] - old_sparse_size[0]
rowcount
,
rowptr
=
self
.
_rowcount
,
self
.
_rowptr
#
rowcount, rowptr = self._rowcount, self._rowptr
if
diff_0
>
0
:
#
if diff_0 > 0:
if
rowptr
is
not
None
:
#
if rowptr is not None:
rowptr
=
torch
.
cat
([
rowptr
,
rowptr
.
new_full
((
diff_0
,
),
nnz
)])
#
rowptr = torch.cat([rowptr, rowptr.new_full((diff_0, ), nnz)])
if
rowcount
is
not
None
:
#
if rowcount is not None:
rowcount
=
torch
.
cat
([
rowcount
,
rowcount
.
new_zeros
(
diff_0
)])
#
rowcount = torch.cat([rowcount, rowcount.new_zeros(diff_0)])
else
:
#
else:
if
rowptr
is
not
None
:
#
if rowptr is not None:
rowptr
=
rowptr
[:
-
diff_0
]
#
rowptr = rowptr[:-diff_0]
if
rowcount
is
not
None
:
#
if rowcount is not None:
rowcount
=
rowcount
[:
-
diff_0
]
#
rowcount = rowcount[:-diff_0]
diff_1
=
sizes
[
1
]
-
old_sparse_size
[
1
]
#
diff_1 = sizes[1] - old_sparse_size[1]
colcount
,
colptr
=
self
.
_colcount
,
self
.
_colptr
#
colcount, colptr = self._colcount, self._colptr
if
diff_1
>
0
:
#
if diff_1 > 0:
if
colptr
is
not
None
:
#
if colptr is not None:
colptr
=
torch
.
cat
([
colptr
,
colptr
.
new_full
((
diff_1
,
),
nnz
)])
#
colptr = torch.cat([colptr, colptr.new_full((diff_1, ), nnz)])
if
colcount
is
not
None
:
#
if colcount is not None:
colcount
=
torch
.
cat
([
colcount
,
colcount
.
new_zeros
(
diff_1
)])
#
colcount = torch.cat([colcount, colcount.new_zeros(diff_1)])
else
:
#
else:
if
colptr
is
not
None
:
#
if colptr is not None:
colptr
=
colptr
[:
-
diff_1
]
#
colptr = colptr[:-diff_1]
if
colcount
is
not
None
:
#
if colcount is not None:
colcount
=
colcount
[:
-
diff_1
]
#
colcount = colcount[:-diff_1]
return
self
.
__class__
(
row
=
self
.
_row
,
rowptr
=
rowptr
,
col
=
self
.
col
,
#
return self.__class__(row=self._row, rowptr=rowptr, col=self.col,
value
=
self
.
value
,
sparse_size
=
sizes
,
#
value=self.value, sparse_size=sizes,
rowcount
=
rowcount
,
colptr
=
colptr
,
#
rowcount=rowcount, colptr=colptr,
colcount
=
colcount
,
csr2csc
=
self
.
_csr2csc
,
#
colcount=colcount, csr2csc=self._csr2csc,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
#
csc2csr=self._csc2csr, is_sorted=True)
def
has_rowcount
(
self
):
def
has_rowcount
(
self
)
->
bool
:
return
self
.
_rowcount
is
not
None
return
self
.
_rowcount
is
not
None
@
cached_property
def
rowcount
(
self
)
->
torch
.
Tensor
:
def
rowcount
(
self
):
rowcount
=
self
.
_rowcount
return
self
.
rowptr
[
1
:]
-
self
.
rowptr
[:
-
1
]
if
rowcount
is
not
None
:
return
rowcount
rowptr
=
self
.
rowptr
()
rowcount
=
rowptr
[
1
:]
-
rowptr
[
1
:]
self
.
_rowcount
=
rowcount
return
rowcount
def
has_colptr
(
self
):
def
has_colptr
(
self
)
->
bool
:
return
self
.
_colptr
is
not
None
return
self
.
_colptr
is
not
None
@
cached_property
def
colptr
(
self
)
->
torch
.
Tensor
:
def
colptr
(
self
):
colptr
=
self
.
_colptr
if
self
.
has_csr2csc
():
if
colptr
is
not
None
:
return
ext
(
self
.
col
.
is_cuda
).
ind2ptr
(
self
.
col
[
self
.
csr2csc
],
self
.
sparse_size
[
1
])
else
:
colptr
=
self
.
col
.
new_zeros
(
self
.
sparse_size
[
1
]
+
1
)
torch
.
cumsum
(
self
.
colcount
,
dim
=
0
,
out
=
colptr
[
1
:])
return
colptr
return
colptr
def
has_colcount
(
self
):
csr2csc
=
self
.
_csr2csc
if
csr2csc
is
not
None
:
colptr
=
torch
.
ops
.
torch_sparse_cpu
.
ind2ptr
(
self
.
_col
[
csr2csc
],
self
.
_sparse_size
[
1
])
else
:
colptr
=
self
.
_col
.
new_zeros
(
self
.
_sparse_size
[
1
]
+
1
)
torch
.
cumsum
(
self
.
colcount
(),
dim
=
0
,
out
=
colptr
[
1
:])
self
.
_colptr
=
colptr
return
colptr
def
has_colcount
(
self
)
->
bool
:
return
self
.
_colcount
is
not
None
return
self
.
_colcount
is
not
None
@
cached_property
def
colcount
(
self
)
->
torch
.
Tensor
:
def
colcount
(
self
):
colcount
=
self
.
_colcount
if
self
.
has_colptr
():
if
colcount
is
not
None
:
return
self
.
colptr
[
1
:]
-
self
.
colptr
[:
-
1
]
return
colcount
colptr
=
self
.
_colptr
if
colptr
is
not
None
:
colcount
=
colptr
[
1
:]
-
colptr
[
1
:]
else
:
else
:
return
scatter_add
(
torch
.
ones_like
(
self
.
col
),
self
.
col
,
raise
NotImplementedError
dim_size
=
self
.
sparse_size
[
1
])
# colcount = scatter_add(torch.ones_like(self._col), self._col,
# dim_size=self._sparse_size[1])
self
.
_colcount
=
colcount
return
colcount
def
has_csr2csc
(
self
):
def
has_csr2csc
(
self
)
->
bool
:
return
self
.
_csr2csc
is
not
None
return
self
.
_csr2csc
is
not
None
@
cached_property
def
csr2csc
(
self
)
->
torch
.
Tensor
:
def
csr2csc
(
self
):
csr2csc
=
self
.
_csr2csc
i
dx
=
self
.
sparse_size
[
0
]
*
self
.
col
+
self
.
row
i
f
csr2csc
is
not
None
:
return
idx
.
argsort
()
return
csr2csc
def
has_csc2csr
(
self
):
idx
=
self
.
_sparse_size
[
0
]
*
self
.
_col
+
self
.
row
()
csr2csc
=
idx
.
argsort
()
self
.
_csr2csc
=
csr2csc
return
csr2csc
def
has_csc2csr
(
self
)
->
bool
:
return
self
.
_csc2csr
is
not
None
return
self
.
_csc2csr
is
not
None
@
cached_property
def
csc2csr
(
self
)
->
torch
.
Tensor
:
def
csc2csr
(
self
):
csc2csr
=
self
.
_csc2csr
return
self
.
csr2csc
.
argsort
()
if
csc2csr
is
not
None
:
return
csc2csr
def
is_coalesced
(
self
):
csc2csr
=
self
.
csr2csc
().
argsort
()
idx
=
self
.
col
.
new_full
((
self
.
col
.
numel
()
+
1
,
),
-
1
)
self
.
_csc2csr
=
csc2csr
idx
[
1
:]
=
self
.
sparse_size
[
1
]
*
self
.
row
+
self
.
col
return
csc2csr
return
(
idx
[
1
:]
>
idx
[:
-
1
]).
all
().
item
()
def
coalesce
(
self
,
reduce
=
'add'
):
def
is_coalesced
(
self
)
->
bool
:
idx
=
self
.
col
.
new_full
((
self
.
col
.
numel
()
+
1
,
),
-
1
)
idx
=
self
.
_col
.
new_full
((
self
.
_col
.
numel
()
+
1
,
),
-
1
)
idx
[
1
:]
=
self
.
sparse_size
[
1
]
*
self
.
row
+
self
.
col
idx
[
1
:]
=
self
.
_sparse_size
[
1
]
*
self
.
row
()
+
self
.
_col
return
bool
((
idx
[
1
:]
>
idx
[:
-
1
]).
all
())
def
coalesce
(
self
,
reduce
:
str
=
"add"
):
idx
=
self
.
_col
.
new_full
((
self
.
_col
.
numel
()
+
1
,
),
-
1
)
idx
[
1
:]
=
self
.
_sparse_size
[
1
]
*
self
.
row
()
+
self
.
_col
mask
=
idx
[
1
:]
>
idx
[:
-
1
]
mask
=
idx
[
1
:]
>
idx
[:
-
1
]
if
mask
.
all
():
# Skip if indices are already coalesced.
if
mask
.
all
():
# Skip if indices are already coalesced.
return
self
return
self
row
=
self
.
row
[
mask
]
row
=
self
.
row
()
[
mask
]
col
=
self
.
col
[
mask
]
col
=
self
.
_
col
[
mask
]
value
=
self
.
value
value
=
self
.
_
value
if
self
.
has_value
()
:
if
value
is
not
None
:
ptr
=
mask
.
nonzero
().
flatten
()
ptr
=
mask
.
nonzero
().
flatten
()
ptr
=
torch
.
cat
([
ptr
,
ptr
.
new_full
((
1
,
),
value
.
size
(
0
))])
ptr
=
torch
.
cat
([
ptr
,
ptr
.
new_full
((
1
,
),
value
.
size
(
0
))])
value
=
segment_csr
(
value
,
ptr
,
reduce
=
reduce
)
raise
NotImplementedError
# value = segment_csr(value, ptr, reduce=reduce)
value
=
value
[
0
]
if
isinstance
(
value
,
tuple
)
else
value
value
=
value
[
0
]
if
isinstance
(
value
,
tuple
)
else
value
return
self
.
__class__
(
row
=
row
,
col
=
col
,
value
=
value
,
return
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_size
=
self
.
sparse_size
,
is_sorted
=
True
)
sparse_size
=
self
.
_sparse_size
,
rowcount
=
None
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
def
cached_keys
(
self
):
csc2csr
=
None
,
is_sorted
=
True
)
return
[
key
for
key
in
self
.
cache_keys
def
fill_cache_
(
self
):
if
getattr
(
self
,
f
'_
{
key
}
'
,
None
)
is
not
None
self
.
row
()
]
self
.
rowptr
()
self
.
rowcount
()
def
fill_cache_
(
self
,
*
args
):
self
.
colptr
()
for
arg
in
args
or
self
.
cache_keys
+
[
'row'
,
'rowptr'
]:
self
.
colcount
()
getattr
(
self
,
arg
)
self
.
csr2csc
()
self
.
csc2csr
()
return
self
return
self
def
clear_cache_
(
self
,
*
args
):
def
clear_cache_
(
self
):
for
arg
in
args
or
self
.
cache_keys
:
self
.
_rowcount
=
None
setattr
(
self
,
f
'_
{
arg
}
'
,
None
)
self
.
_colptr
=
None
self
.
_colcount
=
None
self
.
_csr2csc
=
None
self
.
_csc2csr
=
None
return
self
return
self
def
__copy__
(
self
):
def
__copy__
(
self
):
return
self
.
apply
(
lambda
x
:
x
)
return
SparseStorage
(
row
=
self
.
_row
,
rowptr
=
self
.
_rowptr
,
col
=
self
.
_col
,
value
=
self
.
_value
,
sparse_size
=
self
.
_sparse_size
,
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
colcount
=
self
.
_colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
def
clone
(
self
):
def
clone
(
self
):
return
self
.
apply
(
lambda
x
:
x
.
clone
())
row
=
self
.
_row
if
row
is
not
None
:
def
__deepcopy__
(
self
,
memo
):
row
=
row
.
clone
()
new_storage
=
self
.
clone
()
rowptr
=
self
.
_rowptr
memo
[
id
(
self
)]
=
new_storage
if
rowptr
is
not
None
:
return
new_storage
rowptr
=
rowptr
.
clone
()
value
=
self
.
_value
def
apply_value_
(
self
,
func
):
if
value
is
not
None
:
self
.
_value
=
optional
(
func
,
self
.
value
)
value
=
value
.
clone
()
return
self
rowcount
=
self
.
_rowcount
if
rowcount
is
not
None
:
def
apply_value
(
self
,
func
):
rowcount
=
rowcount
.
clone
()
return
self
.
__class__
(
row
=
self
.
_row
,
rowptr
=
self
.
_rowptr
,
col
=
self
.
col
,
colptr
=
self
.
_colptr
value
=
optional
(
func
,
self
.
value
),
if
colptr
is
not
None
:
sparse_size
=
self
.
sparse_size
,
colptr
=
colptr
.
clone
()
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
colcount
=
self
.
_colcount
colcount
=
self
.
_colcount
,
csr2csc
=
self
.
_csr2csc
,
if
colcount
is
not
None
:
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
colcount
=
colcount
.
clone
()
csr2csc
=
self
.
_csr2csc
def
apply_
(
self
,
func
):
if
csr2csc
is
not
None
:
self
.
_row
=
optional
(
func
,
self
.
_row
)
csr2csc
=
csr2csc
.
clone
()
self
.
_rowptr
=
optional
(
func
,
self
.
_rowptr
)
csc2csr
=
self
.
_csc2csr
self
.
_col
=
func
(
self
.
col
)
if
csc2csr
is
not
None
:
self
.
_value
=
optional
(
func
,
self
.
value
)
csc2csr
=
csc2csr
.
clone
()
for
key
in
self
.
cached_keys
():
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
self
.
_col
.
clone
(),
setattr
(
self
,
f
'_
{
key
}
'
,
func
(
getattr
(
self
,
f
'_
{
key
}
'
)))
value
=
value
,
sparse_size
=
self
.
_sparse_size
,
return
self
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
def
apply
(
self
,
func
):
csc2csr
=
csc2csr
,
is_sorted
=
True
)
return
self
.
__class__
(
row
=
optional
(
func
,
self
.
_row
),
def
__deepcopy__
(
self
,
memo
:
Dict
[
str
,
Any
]):
rowptr
=
optional
(
func
,
self
.
_rowptr
),
return
self
.
clone
()
col
=
func
(
self
.
col
),
value
=
optional
(
func
,
self
.
value
),
sparse_size
=
self
.
sparse_size
,
rowcount
=
optional
(
func
,
self
.
_rowcount
),
colptr
=
optional
(
func
,
self
.
_colptr
),
colcount
=
optional
(
func
,
self
.
_colcount
),
csr2csc
=
optional
(
func
,
self
.
_csr2csc
),
csc2csr
=
optional
(
func
,
self
.
_csc2csr
),
is_sorted
=
True
,
)
def
map
(
self
,
func
):
data
=
[]
if
self
.
has_row
():
data
+=
[
func
(
self
.
row
)]
if
self
.
has_rowptr
():
data
+=
[
func
(
self
.
rowptr
)]
data
+=
[
func
(
self
.
col
)]
if
self
.
has_value
():
data
+=
[
func
(
self
.
value
)]
data
+=
[
func
(
getattr
(
self
,
f
'_
{
key
}
'
))
for
key
in
self
.
cached_keys
()]
return
data
torch_sparse/utils.py
View file @
f59fe649
from
typing
import
Any
import
torch
import
torch
try
:
from
typing_extensions
import
Final
# noqa
except
ImportError
:
from
torch.jit
import
Final
# noqa
torch
.
ops
.
load_library
(
'torch_sparse/convert_cpu.so'
)
torch
.
ops
.
load_library
(
'torch_sparse/convert_cpu.so'
)
torch
.
ops
.
load_library
(
'torch_sparse/diag_cpu.so'
)
torch
.
ops
.
load_library
(
'torch_sparse/diag_cpu.so'
)
torch
.
ops
.
load_library
(
'torch_sparse/spmm_cpu.so'
)
torch
.
ops
.
load_library
(
'torch_sparse/spmm_cpu.so'
)
...
@@ -14,10 +21,5 @@ except OSError as e:
...
@@ -14,10 +21,5 @@ except OSError as e:
raise
e
raise
e
def
ext
(
is_cuda
):
def
is_scalar
(
other
:
Any
)
->
bool
:
name
=
'torch_sparse_cuda'
if
is_cuda
else
'torch_sparse_cpu'
return
getattr
(
torch
.
ops
,
name
)
def
is_scalar
(
other
):
return
isinstance
(
other
,
int
)
or
isinstance
(
other
,
float
)
return
isinstance
(
other
,
int
)
or
isinstance
(
other
,
float
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment