Commit 7f7036cd authored by rusty1s's avatar rusty1s
Browse files

update transpose:

parent 592d63d2
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.storage import SparseStorage
from torch_sparse import SparseStorage, SparseTensor
from typing import Dict, Any
......@@ -73,6 +72,10 @@ def test_jit():
# mat = SparseTensor.from_scipy(scipy)
print()
print(adj)
# adj = t(adj)
adj = adj.t()
print(adj)
# print(adj.t)
# adj = {'rowptr': mat.storage.rowptr, 'col': mat.storage.col}
# foo = Foo(mat.storage.rowptr, mat.storage.col)
......
......@@ -19,3 +19,7 @@ __all__ = [
'spmm',
'spspmm',
]
from .storage import SparseStorage
from .tensor import SparseTensor
from .transpose import t
......@@ -6,7 +6,6 @@ import scipy.sparse
from torch_sparse.storage import SparseStorage, get_layout
# from torch_sparse.transpose import t
# from torch_sparse.narrow import narrow
# from torch_sparse.select import select
# from torch_sparse.index_select import index_select, index_select_nnz
......@@ -406,9 +405,6 @@ class SparseTensor(object):
# return matmul(self, other, reduce='sum')
# Bindings ####################################################################
# SparseTensor.t = t
# SparseTensor.narrow = narrow
# SparseTensor.select = select
# SparseTensor.index_select = index_select
......
import torch
from torch_sparse import to_scipy, from_scipy, coalesce
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
def transpose(index, value, m, n, coalesced=True):
"""Transposes dimensions 0 and 1 of a sparse tensor.
......@@ -28,15 +31,23 @@ def transpose(index, value, m, n, coalesced=True):
return index, value
def t(src):
csr2csc = src.storage.csr2csc
@torch.jit.script
def t(src: SparseTensor):
csr2csc = src.storage.csr2csc()
row, col, value = src.coo()
if value is not None:
value = value[csr2csc]
storage = src.storage.__class__(
row=src.storage.col[csr2csc],
sparse_sizes = src.storage.sparse_sizes()
storage = SparseStorage(
row=col[csr2csc],
rowptr=src.storage._colptr,
col=src.storage.row[csr2csc],
value=src.storage.value[csr2csc] if src.has_value() else None,
sparse_size=src.storage.sparse_size[::-1],
col=row[csr2csc],
value=value,
sparse_sizes=torch.Size([sparse_sizes[1], sparse_sizes[0]]),
rowcount=src.storage._colcount,
colptr=src.storage._rowptr,
colcount=src.storage._rowcount,
......@@ -46,3 +57,6 @@ def t(src):
)
return src.from_storage(storage)
SparseTensor.t = lambda self: t(self)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment