Commit 5788c855 authored by rusty1s

clean up code

parent ef2c346f

test/test_spspmm.py
@@ -2,7 +2,7 @@ from itertools import product
 import pytest
 import torch
-from torch_sparse import spspmm, SparseTensor
+from torch_sparse import spspmm
 from .utils import dtypes, devices, tensor
@@ -17,9 +17,13 @@ def test_spspmm(dtype, device):
     value = tensor([2, 4], dtype, device)
     B = (index, value, torch.Size([3, 2]))

-    index, value = spspmm(*A, *B)
-    out = SparseTensor(index, value, torch.Size([3, 2]))
-    assert out.to_dense().tolist() == [[8, 0], [0, 6], [0, 8]]
+    index, value, size = spspmm(*A, *B)
+    print(index)
+    print(value)
+    print(size)
+    # out = torch.sparse_coo_tensor(index, value, size)
+    # assert out.to_dense().tolist() == [[8, 0], [0, 6], [0, 8]]

     # TODO TEST backward
     # value.sum().backward()
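
Note: a minimal sketch of the new call convention exercised above. The entries of A are hypothetical (this diff does not show how the test builds A); they were picked so the product matches the commented-out expectation.

    import torch
    from torch_sparse import spspmm

    # Hypothetical COO inputs: A is 3x3, B is 3x2.
    indexA = torch.tensor([[0, 1, 2], [2, 0, 0]])
    valueA = torch.tensor([2.0, 3.0, 4.0])
    indexB = torch.tensor([[0, 2], [1, 0]])
    valueB = torch.tensor([2.0, 4.0])

    # spspmm now also returns the size of the product.
    index, value, size = spspmm(indexA, valueA, torch.Size([3, 3]),
                                indexB, valueB, torch.Size([3, 2]))
    out = torch.sparse_coo_tensor(index, value, size)
    print(out.to_dense())  # [[8., 0.], [0., 6.], [0., 8.]]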

torch_sparse/__init__.py
 from .sparse import SparseTensor
 from .matmul import spspmm
+from .transpose import transpose

 __all__ = [
     'SparseTensor',
     'spspmm',
+    'transpose',
 ]

torch_sparse/matmul.py
 import torch
-from torch import from_numpy
-from scipy.sparse import coo_matrix
-from torch_sparse import SparseTensor
+import scipy.sparse
+from torch_sparse import transpose

-import matmul_cuda
+if torch.cuda.is_available():
+    import matmul_cuda

+def spspmm(indexA, valueA, sizeA, indexB, valueB, sizeB):
+    assert valueA.dtype == valueB.dtype
+    assert len(sizeA) == len(sizeB) == 2
+    assert sizeA[1] == sizeB[0]
+    index, value = SpSpMM.apply(indexA, valueA, sizeA, indexB, valueB, sizeB)
+    size = torch.Size([sizeA[0], sizeB[1]])
+    return index, value, size
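
Note: the wrapper validates shapes and computes the n x m output size outside the autograd Function, since sizes are not differentiable outputs. A quick sketch of the inner-dimension check, with hypothetical inputs:

    import torch
    from torch_sparse import spspmm

    indexA, valueA = torch.tensor([[0], [1]]), torch.tensor([1.0])
    indexB, valueB = torch.tensor([[0], [0]]), torch.tensor([1.0])
    try:
        # 3x3 times 2x2: inner dimensions differ, so the assert fires.
        spspmm(indexA, valueA, torch.Size([3, 3]),
               indexB, valueB, torch.Size([2, 2]))
    except AssertionError:
        print('sizeA[1] must equal sizeB[0]')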

 class SpSpMM(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, e1, v1, s1, e2, v2, s2):
-        e, v = mm(e1, v1, s1, e2, v2, s2)
-        ctx.s1, ctx.s2 = s1, s2
-        ctx.save_for_backward(e1, v1, e2, v2, e)
-        return e, v
+    def forward(ctx, indexA, valueA, sizeA, indexB, valueB, sizeB):
+        index, value = mm(indexA, valueA, sizeA, indexB, valueB, sizeB)
+        ctx.sizeA, ctx.sizeB = sizeA, sizeB
+        ctx.save_for_backward(indexA, valueA, indexB, valueB, index)
+        return index, value

     @staticmethod
-    def backward(ctx, grad_e, grad_v):
-        e1, v1, e2, v2, e = ctx.saved_variables
-        grad_v1 = grad_v2 = None
-        grad = (e, grad_v, torch.Size([ctx.s1[0], ctx.s2[1]]))
+    def backward(ctx, grad_index, grad_value):
+        indexA, valueA, indexB, valueB, index = ctx.saved_variables
+        grad_valueA = grad_valueB = None
+        grad = (index, grad_value, torch.Size([ctx.sizeA[0], ctx.sizeB[1]]))
         if ctx.needs_input_grad[1]:
-            e2 = torch.stack([e2[1], e2[0]], dim=0)
-            _, grad_v1 = mm(*grad, e2, v2, torch.Size([ctx.s2[1], ctx.s2[0]]))
+            B_transposed = transpose(indexB, valueB, ctx.sizeB)
+            _, grad_valueA = mm(*grad, *B_transposed)
         if ctx.needs_input_grad[4]:
-            e1 = torch.stack([e1[1], e1[0]], dim=0)
-            _, grad_v2 = mm(e1, v1, torch.Size([ctx.s1[1], ctx.s1[0]]), *grad)
+            A_transposed = transpose(indexA, valueA, ctx.sizeA)
+            _, grad_valueB = mm(*A_transposed, *grad)
-        return None, grad_v1, None, None, grad_v2, None
+        return None, grad_valueA, None, None, grad_valueB, None

-spspmm = SpSpMM.apply
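
Note: the backward pass applies the usual matrix-calculus rule for C = A @ B, namely dL/dA = dL/dC @ B.T and dL/dB = A.T @ dL/dC, with transpose() supplying B.T and A.T. A dense sanity check of that rule:

    import torch

    # For C = A @ B: dL/dA = dL/dC @ B.T and dL/dB = A.T @ dL/dC.
    A = torch.rand(3, 3, requires_grad=True)
    B = torch.rand(3, 2, requires_grad=True)
    grad_C = torch.rand(3, 2)
    (A @ B).backward(grad_C)
    assert torch.allclose(A.grad, grad_C @ B.t())
    assert torch.allclose(B.grad, A.t() @ grad_C)

In the sparse version, the index of each gradient product is discarded (the `_`), which implicitly assumes its sparsity pattern lines up with the stored values of A (resp. B); that is presumably why the backward test above is still a TODO.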

-def mm(e1, v1, s1, e2, v2, s2):
-    if v1.is_cuda:
-        return mm_cuda(e1, v1, s1, e2, v2, s2)
-    else:
-        return mm_cpu(e1, v1, s1, e2, v2, s2)
+def mm(indexA, valueA, sizeA, indexB, valueB, sizeB):
+    if valueA.is_cuda:
+        return mm_cuda(indexA, valueA, sizeA, indexB, valueB, sizeB)
+    else:
+        return mm_cpu(indexA, valueA, sizeA, indexB, valueB, sizeB)

-def mm_cuda(e1, v1, s1, e2, v2, s2):
-    matrix1 = SparseTensor(e1, v1, s1)
-    matrix2 = SparseTensor(e2, v2, s2)
-    return matmul_cuda.spspmm(matrix1, matrix2)
+def mm_cuda(indexA, valueA, sizeA, indexB, valueB, sizeB):
+    A = torch.sparse_coo_tensor(indexA, valueA, sizeA)
+    B = torch.sparse_coo_tensor(indexB, valueB, sizeB)
+    index, value = matmul_cuda.spspmm(A, B)
+    return index, value
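
Note: torch.sparse_coo_tensor is PyTorch's built-in COO constructor and replaces the custom SparseTensor wrapper on this path. For illustration:

    import torch

    index = torch.tensor([[0, 2], [1, 0]])
    value = torch.tensor([2.0, 4.0])
    A = torch.sparse_coo_tensor(index, value, torch.Size([3, 2]))
    print(A.to_dense())  # [[0., 2.], [0., 0.], [4., 0.]]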

-def mm_cpu(e1, v1, s1, e2, v2, s2):
-    matrix1, matrix2 = to_csr(e1, v1, s1), to_csr(e2, v2, s2)
-    out = matrix1.dot(matrix2).tocoo()
-    row, col = from_numpy(out.row).long(), from_numpy(out.col).long()
-    return torch.stack([row, col], dim=0), from_numpy(out.data)
+def mm_cpu(indexA, valueA, sizeA, indexB, valueB, sizeB):
+    A, B = to_scipy(indexA, valueA, sizeA), to_scipy(indexB, valueB, sizeB)
+    C = A.tocsr().dot(B.tocsr()).tocoo()
+    row, col = torch.from_numpy(C.row).long(), torch.from_numpy(C.col).long()
+    index = torch.stack([row, col], dim=0)
+    value = torch.from_numpy(C.data).type_as(valueA)
+    return index, value

-def to_csr(index, value, size):
-    index, value = index.detach().numpy(), value.detach().numpy()
-    shape = (size[0], size[1])
-    return coo_matrix((value, (index[0], index[1])), shape).tocsr()
+def to_scipy(index, value, size):
+    (row, col), value = index.detach().numpy(), value.detach().numpy()
+    return scipy.sparse.coo_matrix((value, (row, col)), tuple(size))
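
Note: the CPU path leans on SciPy: COO tensors become scipy matrices, the product runs in CSR form (which supports fast dot products), and the result is read back from COO arrays. A round-trip sketch with hypothetical data:

    import torch
    import scipy.sparse

    index = torch.tensor([[0, 1, 2], [2, 0, 0]])
    value = torch.tensor([2.0, 3.0, 4.0])
    A = scipy.sparse.coo_matrix(
        (value.numpy(), (index[0].numpy(), index[1].numpy())), (3, 3))
    C = A.tocsr().dot(A.tocsr()).tocoo()
    print(C.row, C.col, C.data)  # COO arrays of A @ A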

torch_sparse/transpose.py (new file)
+import torch
+
+
+def transpose(index, value, size):
+    (row, col), (dim1, dim2) = index, size
+    index, size = torch.stack([col, row], dim=0), torch.Size([dim2, dim1])
+    return index, value, size
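
Note: transpose() swaps the row/col index rows and the two size entries; values are untouched and the resulting COO indices are left in their original (now unsorted) order. Usage sketch:

    import torch
    from torch_sparse import transpose

    index = torch.tensor([[0, 2], [1, 0]])
    value = torch.tensor([2.0, 4.0])
    index_t, value_t, size_t = transpose(index, value, torch.Size([3, 2]))
    print(index_t.tolist())  # [[1, 0], [0, 2]]
    print(size_t)            # torch.Size([2, 3])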