Commit 9732a518 authored by rusty1s's avatar rusty1s
Browse files

torch sparse convert + transpose cleanup

parent 0b790779
import torch
from torch_sparse import to_scipy, from_scipy
from torch_sparse import to_torch_sparse, from_torch_sparse
def test_convert_scipy():
    """Round-trip an (index, value) sparse representation through scipy."""
    index = torch.tensor([[0, 0, 1, 2, 2], [0, 2, 1, 0, 1]])
    value = torch.Tensor([1, 2, 4, 1, 3])
    N = 3

    index_out, value_out = from_scipy(to_scipy(index, value, N, N))

    # Conversion must be lossless in both index and value.
    assert index_out.tolist() == index.tolist()
    assert value_out.tolist() == value.tolist()
def test_convert_torch_sparse():
    """Round-trip an (index, value) sparse representation through
    ``torch.sparse`` COO tensors."""
    index = torch.tensor([[0, 0, 1, 2, 2], [0, 2, 1, 0, 1]])
    value = torch.Tensor([1, 2, 4, 1, 3])
    N = 3

    # ``coalesce()`` is required before ``from_torch_sparse`` can read the
    # indices of the COO tensor.
    mat = to_torch_sparse(index, value, N, N).coalesce()
    index_out, value_out = from_torch_sparse(mat)

    assert index_out.tolist() == index.tolist()
    assert value_out.tolist() == value.tolist()
......@@ -2,29 +2,31 @@ from itertools import product
import pytest
import torch
from torch_sparse import transpose, transpose_matrix
from torch_sparse import transpose
from .utils import dtypes, devices, tensor
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_transpose_matrix(dtype, device):
    """Transpose a sparse 3x2 matrix holding one-dimensional (scalar) values.

    The span is a diff artifact that interleaved the removed ``test_transpose``
    lines with the added ``test_transpose_matrix`` ones (two conflicting
    ``value =`` assignments, two conflicting final asserts); this is the
    reconstructed post-commit version.
    """
    row = torch.tensor([1, 0, 1, 2], device=device)
    col = torch.tensor([0, 1, 1, 0], device=device)
    index = torch.stack([row, col], dim=0)
    value = tensor([1, 2, 3, 4], dtype, device)

    index, value = transpose(index, value, m=3, n=2)

    # Transposed entries come back sorted in row-major order of the 2x3 result.
    assert index.tolist() == [[0, 0, 1, 1], [1, 2, 0, 1]]
    assert value.tolist() == [1, 4, 2, 3]
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_transpose(dtype, device):
    """Transpose a sparse 3x2 matrix holding multi-dimensional values.

    The span is a diff artifact interleaving the removed
    ``test_transpose_matrix`` lines with the added ``test_transpose`` ones;
    this is the reconstructed post-commit version.
    """
    row = torch.tensor([1, 0, 1, 0, 2, 1], device=device)
    col = torch.tensor([0, 1, 1, 1, 0, 0], device=device)
    index = torch.stack([row, col], dim=0)
    value = tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]], dtype,
                   device)

    index, value = transpose(index, value, m=3, n=2)

    # Duplicate transposed coordinates, e.g. (0, 1) from entries (1, 0) and
    # (0, 1)... wait — (0, 1) arises from both [1, 2] and [6, 7]; coalescing
    # sums duplicates: [1+6, 2+7] == [7, 9] and [2+4, 3+5] == [6, 8].
    assert index.tolist() == [[0, 0, 1, 1], [1, 2, 0, 1]]
    assert value.tolist() == [[7, 9], [5, 6], [6, 8], [3, 4]]
from .convert import to_scipy, from_scipy
from .convert import to_torch_sparse, from_torch_sparse, to_scipy, from_scipy
from .coalesce import coalesce
from .transpose import transpose, transpose_matrix
from .transpose import transpose
from .eye import eye
from .spmm import spmm
from .spspmm import spspmm
......@@ -9,11 +9,12 @@ __version__ = '0.3.0'
# Public API of the torch_sparse package.  The diff artifact in this span
# still listed the removed 'transpose_matrix' alongside the newly added
# convert helpers; the cleaned post-commit list is:
__all__ = [
    '__version__',
    'to_torch_sparse',
    'from_torch_sparse',
    'to_scipy',
    'from_scipy',
    'coalesce',
    'transpose',
    'eye',
    'spmm',
    'spspmm',
]
......
......@@ -4,6 +4,14 @@ import torch
from torch import from_numpy
def to_torch_sparse(index, value, m, n):
    """Builds a ``torch.sparse`` COO tensor of shape ``[m, n]`` from an
    (index, value) pair.  The index tensor is detached so no gradients flow
    through the sparse layout."""
    size = torch.Size([m, n])
    return torch.sparse_coo_tensor(index.detach(), value, size)
def from_torch_sparse(A):
    """Extracts the (index, value) pair of a ``torch.sparse`` COO tensor.

    NOTE(review): ``A.indices()`` requires ``A`` to be coalesced — callers
    are expected to call ``.coalesce()`` first.
    """
    index = A.indices().detach()
    return index, A.values()
def to_scipy(index, value, m, n):
assert not index.is_cuda and not value.is_cuda
(row, col), data = index.detach(), value.detach()
......
import torch
from torch_sparse import transpose_matrix, to_scipy, from_scipy
from torch_sparse import transpose, to_scipy, from_scipy
import torch_sparse.spspmm_cpu
......@@ -53,9 +53,8 @@ class SpSpMM(torch.autograd.Function):
valueB, m, k)
if ctx.needs_input_grad[3]:
indexA, valueA = transpose_matrix(indexA, valueA, m, k)
indexC, grad_valueC = transpose_matrix(indexC, grad_valueC, m,
n)
indexA, valueA = transpose(indexA, valueA, m, k)
indexC, grad_valueC = transpose(indexC, grad_valueC, m, n)
grad_valueB = torch_sparse.spspmm_cpu.spspmm_bw(
indexB, indexA.detach(), valueA, indexC.detach(),
grad_valueC, k, n)
......@@ -66,7 +65,7 @@ class SpSpMM(torch.autograd.Function):
indexB.detach(), valueB, m, k)
if ctx.needs_input_grad[3]:
indexA_T, valueA_T = transpose_matrix(indexA, valueA, m, k)
indexA_T, valueA_T = transpose(indexA, valueA, m, k)
grad_indexB, grad_valueB = mm(indexA_T, valueA_T, indexC,
grad_valueC, k, m, n)
grad_valueB = lift(grad_indexB, grad_valueB, indexB, n)
......
def transpose(index, value, m, n):
    """Transposes dimensions 0 and 1 of a sparse matrix.

    Args:
        index (:class:`LongTensor`): The index tensor of sparse matrix.
        value (:class:`Tensor`): The value tensor of sparse matrix.
        m (int): The first dimension of sparse matrix.
        n (int): The second dimension of sparse matrix.

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    # NOTE(review): this span was a diff artifact interleaving the removed
    # ``transpose_matrix`` with the merged ``transpose``; reconstructed from
    # the hunk context (`@@ ... def transpose(index, value, m, n):`).
    if value.dim() == 1 and not value.is_cuda:
        # Fast CPU path for scalar values: scipy's CSC conversion yields the
        # transposed matrix with indices already sorted.
        mat = to_scipy(index, value, m, n).tocsc()
        (col, row), value = from_scipy(mat)
        index = torch.stack([row, col], dim=0)
        return index, value

    # General path: swap row/col and re-sort (summing duplicates) via coalesce.
    row, col = index
    index = torch.stack([col, row], dim=0)
    index, value = coalesce(index, value, n, m)
    return index, value
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment