Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
7f7036cd
Commit
7f7036cd
authored
Jan 28, 2020
by
rusty1s
Browse files
update transpose:
parent
592d63d2
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
30 additions
and
13 deletions
+30
-13
test/test_jit.py
test/test_jit.py
+5
-2
torch_sparse/__init__.py
torch_sparse/__init__.py
+4
-0
torch_sparse/tensor.py
torch_sparse/tensor.py
+0
-4
torch_sparse/transpose.py
torch_sparse/transpose.py
+21
-7
No files found.
test/test_jit.py
View file @
7f7036cd
import
torch
import
torch
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse
import
SparseStorage
,
SparseTensor
from
torch_sparse.storage
import
SparseStorage
from
typing
import
Dict
,
Any
from
typing
import
Dict
,
Any
...
@@ -73,6 +72,10 @@ def test_jit():
...
@@ -73,6 +72,10 @@ def test_jit():
# mat = SparseTensor.from_scipy(scipy)
# mat = SparseTensor.from_scipy(scipy)
print
()
print
()
print
(
adj
)
print
(
adj
)
# adj = t(adj)
adj
=
adj
.
t
()
print
(
adj
)
# print(adj.t)
# adj = {'rowptr': mat.storage.rowptr, 'col': mat.storage.col}
# adj = {'rowptr': mat.storage.rowptr, 'col': mat.storage.col}
# foo = Foo(mat.storage.rowptr, mat.storage.col)
# foo = Foo(mat.storage.rowptr, mat.storage.col)
...
...
torch_sparse/__init__.py
View file @
7f7036cd
...
@@ -19,3 +19,7 @@ __all__ = [
...
@@ -19,3 +19,7 @@ __all__ = [
'spmm'
,
'spmm'
,
'spspmm'
,
'spspmm'
,
]
]
from
.storage
import
SparseStorage
from
.tensor
import
SparseTensor
from
.transpose
import
t
torch_sparse/tensor.py
View file @
7f7036cd
...
@@ -6,7 +6,6 @@ import scipy.sparse
...
@@ -6,7 +6,6 @@ import scipy.sparse
from
torch_sparse.storage
import
SparseStorage
,
get_layout
from
torch_sparse.storage
import
SparseStorage
,
get_layout
# from torch_sparse.transpose import t
# from torch_sparse.narrow import narrow
# from torch_sparse.narrow import narrow
# from torch_sparse.select import select
# from torch_sparse.select import select
# from torch_sparse.index_select import index_select, index_select_nnz
# from torch_sparse.index_select import index_select, index_select_nnz
...
@@ -406,9 +405,6 @@ class SparseTensor(object):
...
@@ -406,9 +405,6 @@ class SparseTensor(object):
# return matmul(self, other, reduce='sum')
# return matmul(self, other, reduce='sum')
# Bindings ####################################################################
# SparseTensor.t = t
# SparseTensor.narrow = narrow
# SparseTensor.narrow = narrow
# SparseTensor.select = select
# SparseTensor.select = select
# SparseTensor.index_select = index_select
# SparseTensor.index_select = index_select
...
...
torch_sparse/transpose.py
View file @
7f7036cd
import
torch
import
torch
from
torch_sparse
import
to_scipy
,
from_scipy
,
coalesce
from
torch_sparse
import
to_scipy
,
from_scipy
,
coalesce
from
torch_sparse.storage
import
SparseStorage
from
torch_sparse.tensor
import
SparseTensor
def
transpose
(
index
,
value
,
m
,
n
,
coalesced
=
True
):
def
transpose
(
index
,
value
,
m
,
n
,
coalesced
=
True
):
"""Transposes dimensions 0 and 1 of a sparse tensor.
"""Transposes dimensions 0 and 1 of a sparse tensor.
...
@@ -28,15 +31,23 @@ def transpose(index, value, m, n, coalesced=True):
...
@@ -28,15 +31,23 @@ def transpose(index, value, m, n, coalesced=True):
return
index
,
value
return
index
,
value
def
t
(
src
):
@
torch
.
jit
.
script
csr2csc
=
src
.
storage
.
csr2csc
def
t
(
src
:
SparseTensor
):
csr2csc
=
src
.
storage
.
csr2csc
()
row
,
col
,
value
=
src
.
coo
()
if
value
is
not
None
:
value
=
value
[
csr2csc
]
storage
=
src
.
storage
.
__class__
(
sparse_sizes
=
src
.
storage
.
sparse_sizes
()
row
=
src
.
storage
.
col
[
csr2csc
],
storage
=
SparseStorage
(
row
=
col
[
csr2csc
],
rowptr
=
src
.
storage
.
_colptr
,
rowptr
=
src
.
storage
.
_colptr
,
col
=
src
.
storage
.
row
[
csr2csc
],
col
=
row
[
csr2csc
],
value
=
src
.
storage
.
value
[
csr2csc
]
if
src
.
has_value
()
else
Non
e
,
value
=
valu
e
,
sparse_size
=
src
.
storage
.
sparse_size
[::
-
1
]
,
sparse_size
s
=
torch
.
Size
([
sparse_sizes
[
1
],
sparse_size
s
[
0
]])
,
rowcount
=
src
.
storage
.
_colcount
,
rowcount
=
src
.
storage
.
_colcount
,
colptr
=
src
.
storage
.
_rowptr
,
colptr
=
src
.
storage
.
_rowptr
,
colcount
=
src
.
storage
.
_rowcount
,
colcount
=
src
.
storage
.
_rowcount
,
...
@@ -46,3 +57,6 @@ def t(src):
...
@@ -46,3 +57,6 @@ def t(src):
)
)
return
src
.
from_storage
(
storage
)
return
src
.
from_storage
(
storage
)
SparseTensor
.
t
=
lambda
self
:
t
(
self
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment