Commit 7da1c4c1 authored by rusty1s

restructure

parent 52dcc2e5
@@ -30,12 +30,9 @@ static void init_cusparse() {
 std::tuple<at::Tensor, at::Tensor> spspmm_cuda(at::Tensor A, at::Tensor B) {
   init_cusparse();
 
-  A = A.coalesce();
-  B = B.coalesce();
-
   auto m = A.size(0);
-  auto n = B.size(1);
   auto k = A.size(1);
+  auto n = B.size(1);
 
   auto nnzA = A._nnz();
   auto nnzB = B._nnz();
...
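With the coalesce calls removed from the CUDA entry point, spspmm_cuda now assumes its inputs are already deduplicated. A hedged sketch of the caller-side preparation this implies, mirroring the Python mm wrapper further down (variable names illustrative):

# Sketch only: coalesce the COO entries before invoking the kernel.
# `coalesce` is the helper from this commit; m, k, n are the matrix dims.
indexA, valueA = coalesce(indexA, valueA, m, k)
indexB, valueB = coalesce(indexB, valueB, k, n)
indexC, valueC = matmul_cuda.spspmm(indexA, valueA, indexB, valueB, m, k, n)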
torch_sparse/__init__.py
 from .coalesce import coalesce
-from .sparse import sparse_coo_tensor, to_value
+from .transpose import transpose
 from .matmul import spspmm
 
 __all__ = [
     'coalesce',
-    'sparse_coo_tensor',
-    'to_value',
+    'transpose',
     'spspmm',
 ]
torch_sparse/coalesce.py
@@ -2,8 +2,7 @@ import torch
 import torch_scatter
 
 
-def coalesce(index, value, size, op='add', fill_value=0):
-    m, n = size
+def coalesce(index, value, m, n, op='add', fill_value=0):
     row, col = index
     unique, inv = torch.unique(row * n + col, sorted=True, return_inverse=True)
...
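coalesce now takes the matrix dimensions as separate m and n arguments instead of a size tuple. A minimal sketch of the updated call, with the return convention taken from its use in transpose below:

import torch
from torch_sparse import coalesce

# Two entries share position (0, 1) in a 2 x 3 matrix; with the
# default op='add' their values are summed into a single entry.
index = torch.tensor([[0, 0, 1],
                      [1, 1, 2]])
value = torch.tensor([1.0, 2.0, 3.0])
index, value = coalesce(index, value, 2, 3)
# index -> [[0, 1], [1, 2]], value -> [3.0, 3.0]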
torch_sparse/matmul.py
 import torch
 import scipy.sparse
+from torch_sparse import transpose
 
 if torch.cuda.is_available():
     import matmul_cuda
...
@@ -9,53 +10,54 @@ class SpSpMM(torch.autograd.Function):
     """Sparse matrix product of two sparse tensors with autograd support."""
 
     @staticmethod
-    def forward(ctx, A, B):
-        ctx.save_for_backward(A, B)
-        return mm(A, B)
+    def forward(ctx, indexA, valueA, indexB, valueB, m, k, n):
+        indexC, valueC = mm(indexA, valueA, indexB, valueB, m, k, n)
+        ctx.m, ctx.k, ctx.n = m, k, n
+        ctx.save_for_backward(indexA, valueA, indexB, valueB, indexC)
+        return indexC, valueC
 
     @staticmethod
-    def backward(ctx, grad_C):
-        A, B = ctx.saved_variables
-        grad_A = grad_B = None
-
-        if ctx.needs_input_grad[0]:
-            grad_A = mm(grad_C, B.t().coalesce())
+    def backward(ctx, grad_indexC, grad_valueC):
+        m, k, n = ctx.m, ctx.k, ctx.n
+        indexA, valueA, indexB, valueB, indexC = ctx.saved_variables
+        grad_valueA = grad_valueB = None
 
         if ctx.needs_input_grad[1]:
-            grad_B = mm(A.t(), grad_C)
+            indexB, valueB = transpose(indexB, valueB, k, n)
+            _, grad_valueA = mm(indexC, grad_valueC, indexB, valueB, m, n, k)
+            # TODO: Filter values.
+
+        if ctx.needs_input_grad[4]:
+            indexA, valueA = transpose(indexA, valueA, m, k)
+            _, grad_valueB = mm(indexA, valueA, indexC, grad_valueC, k, m, n)
+            # TODO: Filter values.
 
-        return grad_A, grad_B
+        return None, grad_valueA, None, grad_valueB, None, None, None
 
 
 spspmm = SpSpMM.apply
 
 
-def mm(A, B):
-    assert A.dtype == B.dtype
-    assert A.size(1) == B.size(0)
-    return mm_cuda(A, B) if A.is_cuda else mm_cpu(A, B)
-
-
-def mm_cuda(A, B):
-    index, value = matmul_cuda.spspmm(A, B)
-    size = torch.Size([A.size(0), B.size(1)])
-    return torch.sparse_coo_tensor(index, value, size, device=value.device)
-
-
-def mm_cpu(A, B):
-    return from_scipy(to_scipy(A).dot(to_scipy(B)))
+def mm(indexA, valueA, indexB, valueB, m, k, n):
+    assert valueA.dtype == valueB.dtype
+
+    if indexA.is_cuda:
+        return matmul_cuda.spspmm(indexA, valueA, indexB, valueB, m, k, n)
+
+    A = to_scipy(indexA, valueA, m, k)
+    B = to_scipy(indexB, valueB, k, n)
+    indexC, valueC = from_scipy(A.tocsr().dot(B.tocsr()).tocoo())
+    return indexC, valueC
 
 
-def to_scipy(A):
-    (row, col), data, shape = A._indices(), A._values(), tuple(A.size())
-    row, col, data = row.detach(), col.detach(), data.detach()
-    return scipy.sparse.coo_matrix((data, (row, col)), shape).tocsr()
+def to_scipy(index, value, m, n):
+    (row, col), data = index.detach(), value.detach()
+    return scipy.sparse.coo_matrix((data, (row, col)), (m, n))
 
 
 def from_scipy(A):
-    A = A.tocoo()
-    row, col, value, size = A.row, A.col, A.data, torch.Size(A.shape)
-    row, col = torch.from_numpy(row).long(), torch.from_numpy(col).long()
-    value = torch.from_numpy(value)
-    index = torch.stack([row, col], dim=0)
-    return torch.sparse_coo_tensor(index, value, size)
+    row, col, value = A.row, A.col, A.data
+    index = torch.stack([row, col], dim=0).to(torch.long)
+    return index, value
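After the restructure, spspmm is driven purely by (index, value) pairs plus explicit dimensions m, k, n rather than torch.sparse tensors. A minimal sketch of the new call pattern (the matrices here are illustrative):

import torch
from torch_sparse import spspmm

# A is 2 x 3 with A[0, 2] = 1 and A[1, 1] = 2;
# B is 3 x 2 with B[1, 0] = 3 and B[2, 1] = 4.
indexA = torch.tensor([[0, 1], [2, 1]])
valueA = torch.tensor([1.0, 2.0])
indexB = torch.tensor([[1, 2], [0, 1]])
valueB = torch.tensor([3.0, 4.0])

# C = A @ B is 2 x 2 with C[0, 1] = 1 * 4 and C[1, 0] = 2 * 3.
indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 2, 3, 2)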
torch_sparse/transpose.py (new file)
+import torch
+from torch_sparse import coalesce
+
+
+def transpose(index, value, m, n):
+    row, col = index
+    index = torch.stack([col, row], dim=0)
+    index, value = coalesce(index, value, m, n)
+    return index, value
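And a short sketch of the new transpose helper, which swaps the row and column indices and re-coalesces the result:

import torch
from torch_sparse import transpose

# A 2 x 3 matrix with entries at (0, 1) and (1, 2).
index = torch.tensor([[0, 1],
                      [1, 2]])
value = torch.tensor([1.0, 2.0])
index, value = transpose(index, value, 2, 3)
# index -> [[1, 2], [0, 1]]: the entries of the 3 x 2 transpose.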