Commit 7da1c4c1 authored by rusty1s

restructure

parent 52dcc2e5
......
@@ -30,12 +30,9 @@ static void init_cusparse() {
 std::tuple<at::Tensor, at::Tensor> spspmm_cuda(at::Tensor A, at::Tensor B) {
   init_cusparse();

-  A = A.coalesce();
-  B = B.coalesce();
-
   auto m = A.size(0);
-  auto n = B.size(1);
   auto k = A.size(1);
+  auto n = B.size(1);
   auto nnzA = A._nnz();
   auto nnzB = B._nnz();
......
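For context: spspmm_cuda reads the product dimensions straight off its coalesced inputs, following the C = A * B convention with A of size m x k and B of size k x n. A minimal torch-only sketch of that COO layout, using an invented 3 x 2 matrix (nothing here is taken from the repo):

import torch

# Invented example: a 3 x 2 matrix with entries (0,0)=1, (1,1)=2, (2,0)=3.
A = torch.tensor([[1., 0.], [0., 2.], [3., 0.]]).to_sparse().coalesce()

indexA, valueA = A._indices(), A._values()  # 2 x nnzA index matrix, nnzA values
m, k = A.size(0), A.size(1)                 # the same reads as in the hunk above
assert A._nnz() == valueA.numel() == 3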
from .coalesce import coalesce
from .sparse import sparse_coo_tensor, to_value
from .transpose import transpose
from .matmul import spspmm

__all__ = [
    'coalesce',
    'sparse_coo_tensor',
    'to_value',
    'transpose',
    'spspmm',
]
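After the restructure, the re-exported entry points operate on raw (index, value) pairs instead of torch.sparse tensors. A hedged usage sketch of spspmm (shapes and values invented; without CUDA this exercises the scipy fallback in matmul.py below):

import torch
from torch_sparse import spspmm

# A: 2 x 3 with (0,0)=1, (0,2)=2, (1,1)=3;  B: 3 x 2 with (0,0)=4, (1,1)=5, (2,1)=6.
indexA = torch.tensor([[0, 0, 1], [0, 2, 1]])
valueA = torch.tensor([1.0, 2.0, 3.0])
indexB = torch.tensor([[0, 1, 2], [0, 1, 1]])
valueB = torch.tensor([4.0, 5.0, 6.0])

indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 2, 3, 2)
# C = A @ B holds (0,0)=4, (0,1)=12, (1,1)=15.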
......
@@ -2,8 +2,7 @@ import torch
 import torch_scatter


-def coalesce(index, value, size, op='add', fill_value=0):
-    m, n = size
+def coalesce(index, value, m, n, op='add', fill_value=0):
     row, col = index
     unique, inv = torch.unique(row * n + col, sorted=True, return_inverse=True)
......
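The signature change above replaces the size tuple with explicit m and n: each entry is keyed by its flattened position row * n + col, and duplicates collapse under a single scatter reduction (torch_scatter supplies the configurable op and fill_value in the real code). A torch-only sketch of the 'add' case; the helper name is invented here:

import torch

def coalesce_add_sketch(index, value, m, n):
    row, col = index
    key = row * n + col                         # unique scalar key per (row, col)
    unique, inv = torch.unique(key, sorted=True, return_inverse=True)
    out = torch.zeros(unique.numel(), dtype=value.dtype)
    out.scatter_add_(0, inv, value)             # sum values sharing a key
    index = torch.stack([unique // n, unique % n], dim=0)
    return index, out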
import torch
import scipy.sparse

from torch_sparse import transpose

if torch.cuda.is_available():
    import matmul_cuda
......
@@ -9,53 +10,54 @@ class SpSpMM(torch.autograd.Function):
     """Sparse matrix product of two sparse tensors with autograd support."""

     @staticmethod
-    def forward(ctx, A, B):
-        ctx.save_for_backward(A, B)
-        return mm(A, B)
+    def forward(ctx, indexA, valueA, indexB, valueB, m, k, n):
+        indexC, valueC = mm(indexA, valueA, indexB, valueB, m, k, n)
+        ctx.m, ctx.k, ctx.n = m, k, n
+        ctx.save_for_backward(indexA, valueA, indexB, valueB, indexC)
+        return indexC, valueC

     @staticmethod
-    def backward(ctx, grad_C):
-        A, B = ctx.saved_variables
-        grad_A = grad_B = None
+    def backward(ctx, grad_indexC, grad_valueC):
+        m, k, n = ctx.m, ctx.k, ctx.n
+        indexA, valueA, indexB, valueB, indexC = ctx.saved_variables

-        if ctx.needs_input_grad[0]:
-            grad_A = mm(grad_C, B.t().coalesce())
+        grad_valueA = grad_valueB = None

         if ctx.needs_input_grad[1]:
-            grad_B = mm(A.t(), grad_C)
+            indexB, valueB = transpose(indexB, valueB, k, n)
+            _, grad_valueA = mm(indexC, grad_valueC, indexB, valueB, m, n, k)
+            # TODO: Filter values.

-        return grad_A, grad_B
+        if ctx.needs_input_grad[3]:
+            indexA, valueA = transpose(indexA, valueA, m, k)
+            _, grad_valueB = mm(indexA, valueA, indexC, grad_valueC, k, m, n)
+            # TODO: Filter values.
+
+        return None, grad_valueA, None, grad_valueB, None, None, None


 spspmm = SpSpMM.apply
-def mm(A, B):
-    assert A.dtype == B.dtype
-    assert A.size(1) == B.size(0)
-    return mm_cuda(A, B) if A.is_cuda else mm_cpu(A, B)
+def mm(indexA, valueA, indexB, valueB, m, k, n):
+    assert valueA.dtype == valueB.dtype

+    if indexA.is_cuda:
+        return matmul_cuda.spspmm(indexA, valueA, indexB, valueB, m, k, n)

-def mm_cuda(A, B):
-    index, value = matmul_cuda.spspmm(A, B)
-    size = torch.Size([A.size(0), B.size(1)])
-    return torch.sparse_coo_tensor(index, value, size, device=value.device)
-
-
-def mm_cpu(A, B):
-    return from_scipy(to_scipy(A).dot(to_scipy(B)))
+    A = to_scipy(indexA, valueA, m, k)
+    B = to_scipy(indexB, valueB, k, n)
+    indexC, valueC = from_scipy(A.tocsr().dot(B.tocsr()).tocoo())
+    return indexC, valueC
-def to_scipy(A):
-    (row, col), data, shape = A._indices(), A._values(), tuple(A.size())
-    row, col, data = row.detach(), col.detach(), data.detach()
-    return scipy.sparse.coo_matrix((data, (row, col)), shape).tocsr()
+def to_scipy(index, value, m, n):
+    (row, col), data = index.detach(), value.detach()
+    return scipy.sparse.coo_matrix((data, (row, col)), (m, n))


 def from_scipy(A):
-    A = A.tocoo()
-    row, col, value, size = A.row, A.col, A.data, torch.Size(A.shape)
-    row, col = torch.from_numpy(row).long(), torch.from_numpy(col).long()
-    value = torch.from_numpy(value)
-    index = torch.stack([row, col], dim=0)
-    return torch.sparse_coo_tensor(index, value, size)
+    row, col, value = torch.from_numpy(A.row), torch.from_numpy(A.col), torch.from_numpy(A.data)
+    index = torch.stack([row, col], dim=0).to(torch.long)
+    return index, value
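The backward above follows from C = A * B: dL/dA = dL/dC * B^T and dL/dB = A^T * dL/dC, each evaluated as another sparse-sparse product against a transposed operand. Until the "TODO: Filter values" lands, those unfiltered products must already match the operands' sparsity patterns for the gradients to line up. Under that assumption, a gradcheck sketch over the value tensors (matrices invented so the patterns do align; f is a throwaway wrapper):

import torch
from torch_sparse import spspmm

indexA = torch.tensor([[0, 1], [1, 0]])   # anti-diagonal 2 x 2
valueA = torch.tensor([2.0, 3.0], dtype=torch.double, requires_grad=True)
indexB = torch.tensor([[0, 1], [0, 1]])   # diagonal 2 x 2
valueB = torch.tensor([4.0, 5.0], dtype=torch.double, requires_grad=True)

def f(valueA, valueB):
    _, valueC = spspmm(indexA, valueA, indexB, valueB, 2, 2, 2)
    return valueC

torch.autograd.gradcheck(f, (valueA, valueB))  # exercises both value gradients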
......
import torch

from torch_sparse import coalesce


def transpose(index, value, m, n):
    row, col = index
    index = torch.stack([col, row], dim=0)
    index, value = coalesce(index, value, n, m)
    return index, value
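A small example of the helper: swapping the index rows and re-coalescing yields the COO form of the n x m transpose in row-major order (entries invented):

import torch
from torch_sparse import transpose

# 2 x 3 matrix with entries (0,1)=1 and (1,2)=2.
index = torch.tensor([[0, 1], [1, 2]])
value = torch.tensor([1.0, 2.0])

index_t, value_t = transpose(index, value, 2, 3)
# index_t == [[1, 2], [0, 1]]: entries (1,0)=1.0 and (2,1)=2.0 of the transpose.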