Unverified Commit 7671fcb0 authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Merge pull request #33 from rusty1s/adj

[WIP] SparseTensor Format
parents 1fb5fa4f 704ad420
This diff is collapsed.
This diff is collapsed.
import torch
from torch_sparse import to_scipy, from_scipy, coalesce
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
@torch.jit.script
def t(src: SparseTensor) -> SparseTensor:
csr2csc = src.storage.csr2csc()
row, col, value = src.coo()
if value is not None:
value = value[csr2csc]
sparse_sizes = src.storage.sparse_sizes()
storage = SparseStorage(
row=col[csr2csc],
rowptr=src.storage._colptr,
col=row[csr2csc],
value=value,
sparse_sizes=torch.Size([sparse_sizes[1], sparse_sizes[0]]),
rowcount=src.storage._colcount,
colptr=src.storage._rowptr,
colcount=src.storage._rowcount,
csr2csc=src.storage._csc2csr,
csc2csr=csr2csc,
is_sorted=True,
)
return src.from_storage(storage)
SparseTensor.t = lambda self: t(self)
###############################################################################
def transpose(index, value, m, n, coalesced=True):
......@@ -15,14 +50,14 @@ def transpose(index, value, m, n, coalesced=True):
:rtype: (:class:`LongTensor`, :class:`Tensor`)
"""
if value.dim() == 1 and not value.is_cuda:
mat = to_scipy(index, value, m, n).tocsc()
(col, row), value = from_scipy(mat)
index = torch.stack([row, col], dim=0)
return index, value
row, col = index
index = torch.stack([col, row], dim=0)
row, col = col, row
if coalesced:
index, value = coalesce(index, value, n, m)
return index, value
sparse_sizes = torch.Size([n, m])
storage = SparseStorage(row=row, col=col, value=value,
sparse_sizes=sparse_sizes, is_sorted=False)
storage = storage.coalesce()
row, col, value = storage.row(), storage.col(), storage.value()
return torch.stack([row, col], dim=0), value
from typing import Any
try:
from typing_extensions import Final # noqa
except ImportError:
from torch.jit import Final # noqa
def is_scalar(other: Any) -> bool:
return isinstance(other, int) or isinstance(other, float)
import torch
import numpy as np
if torch.cuda.is_available():
import torch_sparse.unique_cuda
def unique(src):
src = src.contiguous().view(-1)
if src.is_cuda:
out, perm = torch_sparse.unique_cuda.unique(src)
else:
out, perm = np.unique(src.numpy(), return_index=True)
out, perm = torch.from_numpy(out), torch.from_numpy(perm)
return out, perm
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment