Commit 9732a518 authored by rusty1s's avatar rusty1s
Browse files

torch sparse convert + transpose cleanup

parent 0b790779
import torch
from torch_sparse import to_scipy, from_scipy
from torch_sparse import to_torch_sparse, from_torch_sparse
def test_convert_scipy():
index = torch.tensor([[0, 0, 1, 2, 2], [0, 2, 1, 0, 1]])
value = torch.Tensor([1, 2, 4, 1, 3])
N = 3
out = from_scipy(to_scipy(index, value, N, N))
assert out[0].tolist() == index.tolist()
assert out[1].tolist() == value.tolist()
def test_convert_torch_sparse():
index = torch.tensor([[0, 0, 1, 2, 2], [0, 2, 1, 0, 1]])
value = torch.Tensor([1, 2, 4, 1, 3])
N = 3
out = from_torch_sparse(to_torch_sparse(index, value, N, N).coalesce())
assert out[0].tolist() == index.tolist()
assert out[1].tolist() == value.tolist()
...@@ -2,29 +2,31 @@ from itertools import product ...@@ -2,29 +2,31 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch_sparse import transpose, transpose_matrix from torch_sparse import transpose
from .utils import dtypes, devices, tensor from .utils import dtypes, devices, tensor
def test_transpose(): @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
row = torch.tensor([1, 0, 1, 0, 2, 1]) def test_transpose_matrix(dtype, device):
col = torch.tensor([0, 1, 1, 1, 0, 0]) row = torch.tensor([1, 0, 1, 2], device=device)
col = torch.tensor([0, 1, 1, 0], device=device)
index = torch.stack([row, col], dim=0) index = torch.stack([row, col], dim=0)
value = torch.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]]) value = tensor([1, 2, 3, 4], dtype, device)
index, value = transpose(index, value, m=3, n=2) index, value = transpose(index, value, m=3, n=2)
assert index.tolist() == [[0, 0, 1, 1], [1, 2, 0, 1]] assert index.tolist() == [[0, 0, 1, 1], [1, 2, 0, 1]]
assert value.tolist() == [[7, 9], [5, 6], [6, 8], [3, 4]] assert value.tolist() == [1, 4, 2, 3]
@pytest.mark.parametrize('dtype,device', product(dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_transpose_matrix(dtype, device): def test_transpose(dtype, device):
row = torch.tensor([1, 0, 1, 2], device=device) row = torch.tensor([1, 0, 1, 0, 2, 1], device=device)
col = torch.tensor([0, 1, 1, 0], device=device) col = torch.tensor([0, 1, 1, 1, 0, 0], device=device)
index = torch.stack([row, col], dim=0) index = torch.stack([row, col], dim=0)
value = tensor([1, 2, 3, 4], dtype, device) value = tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]], dtype,
device)
index, value = transpose_matrix(index, value, m=3, n=2) index, value = transpose(index, value, m=3, n=2)
assert index.tolist() == [[0, 0, 1, 1], [1, 2, 0, 1]] assert index.tolist() == [[0, 0, 1, 1], [1, 2, 0, 1]]
assert value.tolist() == [1, 4, 2, 3] assert value.tolist() == [[7, 9], [5, 6], [6, 8], [3, 4]]
from .convert import to_scipy, from_scipy from .convert import to_torch_sparse, from_torch_sparse, to_scipy, from_scipy
from .coalesce import coalesce from .coalesce import coalesce
from .transpose import transpose, transpose_matrix from .transpose import transpose
from .eye import eye from .eye import eye
from .spmm import spmm from .spmm import spmm
from .spspmm import spspmm from .spspmm import spspmm
...@@ -9,11 +9,12 @@ __version__ = '0.3.0' ...@@ -9,11 +9,12 @@ __version__ = '0.3.0'
__all__ = [ __all__ = [
'__version__', '__version__',
'to_torch_sparse',
'from_torch_sparse',
'to_scipy', 'to_scipy',
'from_scipy', 'from_scipy',
'coalesce', 'coalesce',
'transpose', 'transpose',
'transpose_matrix',
'eye', 'eye',
'spmm', 'spmm',
'spspmm', 'spspmm',
......
...@@ -4,6 +4,14 @@ import torch ...@@ -4,6 +4,14 @@ import torch
from torch import from_numpy from torch import from_numpy
def to_torch_sparse(index, value, m, n):
return torch.sparse_coo_tensor(index.detach(), value, torch.Size([m, n]))
def from_torch_sparse(A):
return A.indices().detach(), A.values()
def to_scipy(index, value, m, n): def to_scipy(index, value, m, n):
assert not index.is_cuda and not value.is_cuda assert not index.is_cuda and not value.is_cuda
(row, col), data = index.detach(), value.detach() (row, col), data = index.detach(), value.detach()
......
import torch import torch
from torch_sparse import transpose_matrix, to_scipy, from_scipy from torch_sparse import transpose, to_scipy, from_scipy
import torch_sparse.spspmm_cpu import torch_sparse.spspmm_cpu
...@@ -53,9 +53,8 @@ class SpSpMM(torch.autograd.Function): ...@@ -53,9 +53,8 @@ class SpSpMM(torch.autograd.Function):
valueB, m, k) valueB, m, k)
if ctx.needs_input_grad[3]: if ctx.needs_input_grad[3]:
indexA, valueA = transpose_matrix(indexA, valueA, m, k) indexA, valueA = transpose(indexA, valueA, m, k)
indexC, grad_valueC = transpose_matrix(indexC, grad_valueC, m, indexC, grad_valueC = transpose(indexC, grad_valueC, m, n)
n)
grad_valueB = torch_sparse.spspmm_cpu.spspmm_bw( grad_valueB = torch_sparse.spspmm_cpu.spspmm_bw(
indexB, indexA.detach(), valueA, indexC.detach(), indexB, indexA.detach(), valueA, indexC.detach(),
grad_valueC, k, n) grad_valueC, k, n)
...@@ -66,7 +65,7 @@ class SpSpMM(torch.autograd.Function): ...@@ -66,7 +65,7 @@ class SpSpMM(torch.autograd.Function):
indexB.detach(), valueB, m, k) indexB.detach(), valueB, m, k)
if ctx.needs_input_grad[3]: if ctx.needs_input_grad[3]:
indexA_T, valueA_T = transpose_matrix(indexA, valueA, m, k) indexA_T, valueA_T = transpose(indexA, valueA, m, k)
grad_indexB, grad_valueB = mm(indexA_T, valueA_T, indexC, grad_indexB, grad_valueB = mm(indexA_T, valueA_T, indexC,
grad_valueC, k, m, n) grad_valueC, k, m, n)
grad_valueB = lift(grad_indexB, grad_valueB, indexB, n) grad_valueB = lift(grad_indexB, grad_valueB, indexB, n)
......
...@@ -14,31 +14,13 @@ def transpose(index, value, m, n): ...@@ -14,31 +14,13 @@ def transpose(index, value, m, n):
:rtype: (:class:`LongTensor`, :class:`Tensor`) :rtype: (:class:`LongTensor`, :class:`Tensor`)
""" """
row, col = index if value.dim() == 1 and not value.is_cuda:
index = torch.stack([col, row], dim=0)
index, value = coalesce(index, value, n, m)
return index, value
def transpose_matrix(index, value, m, n):
"""Transposes dimensions 0 and 1 of a sparse matrix, where :args:`value` is
one-dimensional.
Args:
index (:class:`LongTensor`): The index tensor of sparse matrix.
value (:class:`Tensor`): The value tensor of sparse matrix.
m (int): The first dimension of sparse matrix.
n (int): The second dimension of sparse matrix.
:rtype: (:class:`LongTensor`, :class:`Tensor`)
"""
assert value.dim() == 1
if index.is_cuda:
return transpose(index, value, m, n)
else:
mat = to_scipy(index, value, m, n).tocsc() mat = to_scipy(index, value, m, n).tocsc()
(col, row), value = from_scipy(mat) (col, row), value = from_scipy(mat)
index = torch.stack([row, col], dim=0) index = torch.stack([row, col], dim=0)
return index, value return index, value
row, col = index
index = torch.stack([col, row], dim=0)
index, value = coalesce(index, value, n, m)
return index, value
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment