OpenDAS / torch-sparse · Commits

Commit f59fe649, authored Jan 27, 2020 by rusty1s
Parent: c4484dbb

    beginning of torch script support

Showing 7 changed files, with 391 additions and 233 deletions:
    test/test_jit.py           +80   -0
    torch_sparse/coalesce.py   +2    -1
    torch_sparse/diag.py       +7    -4
    torch_sparse/matmul.py     +8    -2
    torch_sparse/spspmm.py     +4    -3
    torch_sparse/storage.py    +282  -217
    torch_sparse/utils.py      +8    -6
test/test_jit.py  (new file, mode 100644)

import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.storage import SparseStorage
from typing import Dict, Any


# class MyTensor(dict):
#     def __init__(self, rowptr, col):
#         self['rowptr'] = rowptr
#         self['col'] = col
#
#     def rowptr(self: Dict[str, torch.Tensor]):
#         return self['rowptr']


@torch.jit.script
class Foo:
    rowptr: torch.Tensor
    col: torch.Tensor

    def __init__(self, rowptr: torch.Tensor, col: torch.Tensor):
        self.rowptr = rowptr
        self.col = col


class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(2, 4)

    # def forward(self, x: torch.Tensor, ptr: torch.Tensor) -> torch.Tensor:
    def forward(self, x: torch.Tensor, adj: SparseStorage) -> torch.Tensor:
        out, _ = torch.ops.torch_sparse_cpu.spmm(adj.rowptr(), adj.col(),
                                                 None, x, 'sum')
        return out
        # ind = torch.ops.torch_sparse_cpu.ptr2ind(ptr, ptr[-1].item())
        # # ind = ptr2ind(ptr, E)
        # x_j = x[ind]
        # out = self.linear(x_j)
        # return out


def test_jit():
    my_cell = MyCell()

    # x = torch.rand(3, 2)
    # ptr = torch.tensor([0, 2, 4, 6])
    # out = my_cell(x, ptr)
    # print()
    # print(out)

    # traced_cell = torch.jit.trace(my_cell, (x, ptr))
    # print(traced_cell)
    # out = traced_cell(x, ptr)
    # print(out)

    x = torch.randn(3, 2)
    # adj = torch.randn(3, 3)
    # adj = SparseTensor.from_dense(adj)
    # adj = Foo(adj.storage.rowptr, adj.storage.col)
    # adj = adj.storage
    rowptr = torch.tensor([0, 3, 6, 9])
    col = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2])
    adj = SparseStorage(rowptr=rowptr, col=col)
    # adj = {'rowptr': mat.storage.rowptr, 'col': mat.storage.col}
    # foo = Foo(mat.storage.rowptr, mat.storage.col)
    # adj = MyTensor(mat.storage.rowptr, mat.storage.col)
    traced_cell = torch.jit.script(my_cell)
    print(traced_cell)
    out = traced_cell(x, adj)
    print(out)
    # # print(traced_cell.code)
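The test above depends on the compiled torch_sparse_cpu extension and on SparseStorage being script-compatible, so it will not run standalone. As a minimal, self-contained sketch of the same pattern (torch.jit.script applied to an nn.Module whose forward carries type annotations), the following uses only stock PyTorch; the dense matmul is merely a stand-in for the custom spmm kernel:

import torch


class Cell(torch.nn.Module):
    # Same shape as MyCell above, but the custom spmm kernel is replaced
    # by a plain dense matmul so the example needs no compiled extension.
    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        return adj @ x


scripted = torch.jit.script(Cell())
print(scripted.code)  # the TorchScript code compiled from the annotated forward
print(scripted(torch.randn(3, 2), torch.ones(3, 3)))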
torch_sparse/coalesce.py

 import torch
 import torch_scatter
-from .unique import unique
+# from .unique import unique


 def coalesce(index, value, m, n, op='add', fill_value=0):
     ...
@@ -22,6 +22,7 @@ def coalesce(index, value, m, n, op='add', fill_value=0):
     :rtype: (:class:`LongTensor`, :class:`Tensor`)
     """
+    raise NotImplementedError
     row, col = index
     ...
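Here, and in spspmm.py below, a raise NotImplementedError lands at the top of the function body, presumably fencing off the old index-based implementations while the TorchScript port is in progress.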
torch_sparse/diag.py

 import torch

-from torch_sparse.utils import ext


 def remove_diag(src, k=0):
     row, col, value = src.coo()
     ...
@@ -39,8 +37,13 @@ def set_diag(src, values=None, k=0):
     row, col, value = src.coo()
-    mask = ext(row.is_cuda).non_diag_mask(row, col, src.size(0),
-                                          src.size(1), k)
+    if row.is_cuda:
+        mask = torch.ops.torch_sparse_cuda.non_diag_mask(
+            row, col, src.size(0), src.size(1), k)
+    else:
+        mask = torch.ops.torch_sparse_cpu.non_diag_mask(
+            row, col, src.size(0), src.size(1), k)
     inv_mask = ~mask

     start, num_diag = -k if k < 0 else 0, mask.numel() - row.numel()
     ...
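The likely motivation for unrolling ext(row.is_cuda) into an explicit branch, here and in matmul.py below: TorchScript cannot compile the dynamic getattr(torch.ops, name) lookup that the ext helper performed, whereas a plain if/else over direct torch.ops calls scripts cleanly. A minimal sketch of the pattern, with arithmetic standing in for the extension kernels:

import torch


def dispatch(x: torch.Tensor) -> torch.Tensor:
    # Device dispatch written as an explicit, scriptable branch; each arm
    # would be a direct torch.ops.torch_sparse_{cuda,cpu} call in the
    # real code.
    if x.is_cuda:
        return x * 2  # stand-in for the CUDA kernel
    else:
        return x + 1  # stand-in for the CPU kernel


scripted = torch.jit.script(dispatch)
print(scripted(torch.zeros(2)))  # tensor([1., 1.]) on CPU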
torch_sparse/matmul.py

 import torch
 import scipy.sparse
 from torch_scatter import scatter_add
-from torch_sparse.utils import ext

+ext = None


 class SPMM(torch.autograd.Function):
     @staticmethod
     def forward(ctx, row, rowptr, col, value, mat, rowcount, colptr,
                 csr2csc, reduce):
-        out, arg_out = ext(mat.is_cuda).spmm(rowptr, col, value, mat, reduce)
+        if mat.is_cuda:
+            out, arg_out = torch.ops.torch_sparse_cuda.spmm(
+                rowptr, col, value, mat, reduce)
+        else:
+            out, arg_out = torch.ops.torch_sparse_cpu.spmm(
+                rowptr, col, value, mat, reduce)

         ctx.reduce = reduce
         ctx.save_for_backward(row, rowptr, col, value, mat, rowcount, colptr,
     ...
torch_sparse/spspmm.py

 import torch
 from torch_sparse import transpose, to_scipy, from_scipy, coalesce

-import torch_sparse.spspmm_cpu
+# import torch_sparse.spspmm_cpu

-if torch.cuda.is_available():
-    import torch_sparse.spspmm_cuda
+# if torch.cuda.is_available():
+#     import torch_sparse.spspmm_cuda


 def spspmm(indexA, valueA, indexB, valueB, m, k, n, coalesced=False):
     ...
@@ -25,6 +25,7 @@ def spspmm(indexA, valueA, indexB, valueB, m, k, n, coalesced=False):
     :rtype: (:class:`LongTensor`, :class:`Tensor`)
     """
+    raise NotImplementedError
     if indexA.is_cuda and coalesced:
         indexA, valueA = coalesce(indexA, valueA, m, k)
         indexB, valueB = coalesce(indexB, valueB, k, n)
     ...
torch_sparse/storage.py

    (Diff collapsed in this view: +282, -217.)
torch_sparse/utils.py

+from typing import Any
+
 import torch

+try:
+    from typing_extensions import Final  # noqa
+except ImportError:
+    from torch.jit import Final  # noqa
+
 torch.ops.load_library('torch_sparse/convert_cpu.so')
 torch.ops.load_library('torch_sparse/diag_cpu.so')
 torch.ops.load_library('torch_sparse/spmm_cpu.so')
 ...
@@ -14,10 +21,5 @@ except OSError as e:
     raise e


-def ext(is_cuda):
-    name = 'torch_sparse_cuda' if is_cuda else 'torch_sparse_cpu'
-    return getattr(torch.ops, name)


-def is_scalar(other):
+def is_scalar(other: Any) -> bool:
     return isinstance(other, int) or isinstance(other, float)
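utils.py now prefers typing_extensions.Final and falls back to torch.jit.Final; TorchScript treats Final-annotated module attributes as compile-time constants. A small, self-contained usage sketch (not part of this commit):

import torch

try:
    from typing_extensions import Final  # same fallback as utils.py
except ImportError:
    from torch.jit import Final


class Scale(torch.nn.Module):
    factor: Final[float]  # baked in as a constant when the module is scripted

    def __init__(self, factor: float):
        super(Scale, self).__init__()
        self.factor = factor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.factor


print(torch.jit.script(Scale(2.0))(torch.ones(3)))  # tensor([2., 2., 2.])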