Commit 7f7036cd authored by rusty1s's avatar rusty1s
Browse files

update transpose:

parent 592d63d2
import torch import torch
from torch_sparse.tensor import SparseTensor from torch_sparse import SparseStorage, SparseTensor
from torch_sparse.storage import SparseStorage
from typing import Dict, Any from typing import Dict, Any
...@@ -73,6 +72,10 @@ def test_jit(): ...@@ -73,6 +72,10 @@ def test_jit():
# mat = SparseTensor.from_scipy(scipy) # mat = SparseTensor.from_scipy(scipy)
print() print()
print(adj) print(adj)
# adj = t(adj)
adj = adj.t()
print(adj)
# print(adj.t)
# adj = {'rowptr': mat.storage.rowptr, 'col': mat.storage.col} # adj = {'rowptr': mat.storage.rowptr, 'col': mat.storage.col}
# foo = Foo(mat.storage.rowptr, mat.storage.col) # foo = Foo(mat.storage.rowptr, mat.storage.col)
......
...@@ -19,3 +19,7 @@ __all__ = [ ...@@ -19,3 +19,7 @@ __all__ = [
'spmm', 'spmm',
'spspmm', 'spspmm',
] ]
from .storage import SparseStorage
from .tensor import SparseTensor
from .transpose import t
...@@ -6,7 +6,6 @@ import scipy.sparse ...@@ -6,7 +6,6 @@ import scipy.sparse
from torch_sparse.storage import SparseStorage, get_layout from torch_sparse.storage import SparseStorage, get_layout
# from torch_sparse.transpose import t
# from torch_sparse.narrow import narrow # from torch_sparse.narrow import narrow
# from torch_sparse.select import select # from torch_sparse.select import select
# from torch_sparse.index_select import index_select, index_select_nnz # from torch_sparse.index_select import index_select, index_select_nnz
...@@ -406,9 +405,6 @@ class SparseTensor(object): ...@@ -406,9 +405,6 @@ class SparseTensor(object):
# return matmul(self, other, reduce='sum') # return matmul(self, other, reduce='sum')
# Bindings ####################################################################
# SparseTensor.t = t
# SparseTensor.narrow = narrow # SparseTensor.narrow = narrow
# SparseTensor.select = select # SparseTensor.select = select
# SparseTensor.index_select = index_select # SparseTensor.index_select = index_select
......
import torch import torch
from torch_sparse import to_scipy, from_scipy, coalesce from torch_sparse import to_scipy, from_scipy, coalesce
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
def transpose(index, value, m, n, coalesced=True): def transpose(index, value, m, n, coalesced=True):
"""Transposes dimensions 0 and 1 of a sparse tensor. """Transposes dimensions 0 and 1 of a sparse tensor.
...@@ -28,15 +31,23 @@ def transpose(index, value, m, n, coalesced=True): ...@@ -28,15 +31,23 @@ def transpose(index, value, m, n, coalesced=True):
return index, value return index, value
def t(src): @torch.jit.script
csr2csc = src.storage.csr2csc def t(src: SparseTensor):
csr2csc = src.storage.csr2csc()
row, col, value = src.coo()
if value is not None:
value = value[csr2csc]
storage = src.storage.__class__( sparse_sizes = src.storage.sparse_sizes()
row=src.storage.col[csr2csc],
storage = SparseStorage(
row=col[csr2csc],
rowptr=src.storage._colptr, rowptr=src.storage._colptr,
col=src.storage.row[csr2csc], col=row[csr2csc],
value=src.storage.value[csr2csc] if src.has_value() else None, value=value,
sparse_size=src.storage.sparse_size[::-1], sparse_sizes=torch.Size([sparse_sizes[1], sparse_sizes[0]]),
rowcount=src.storage._colcount, rowcount=src.storage._colcount,
colptr=src.storage._rowptr, colptr=src.storage._rowptr,
colcount=src.storage._rowcount, colcount=src.storage._rowcount,
...@@ -46,3 +57,6 @@ def t(src): ...@@ -46,3 +57,6 @@ def t(src):
) )
return src.from_storage(storage) return src.from_storage(storage)
SparseTensor.t = lambda self: t(self)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment