matmul.py
import torch
import scipy.sparse
from torch_sparse import transpose

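# Only import the custom CUDA extension when a GPU is available.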
if torch.cuda.is_available():
    import matmul_cuda


class SpSpMM(torch.autograd.Function):
    """Sparse matrix product of two sparse tensors with autograd support."""

    @staticmethod
    def forward(ctx, indexA, valueA, indexB, valueB, m, k, n):
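        # A is a sparse (m, k) matrix and B a sparse (k, n) matrix, both given as
        # COO index/value pairs; their product C = A @ B is returned the same way.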
        indexC, valueC = mm(indexA, valueA, indexB, valueB, m, k, n)
        ctx.m, ctx.k, ctx.n = m, k, n
        ctx.save_for_backward(indexA, valueA, indexB, valueB, indexC)
        return indexC, valueC

    @staticmethod
    def backward(ctx, grad_indexC, grad_valueC):
        m, k, n = ctx.m, ctx.k, ctx.n
        indexA, valueA, indexB, valueB, indexC = ctx.saved_tensors
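        # Only the value tensors receive gradients; the index tensors and the
        # integer sizes m, k and n are non-differentiable.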

        grad_valueA = grad_valueB = None

        if ctx.needs_input_grad[1]:
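            # dA = dC @ B^T: transpose B to (n, k), then multiply dC (m, n) by it.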
            indexB, valueB = transpose(indexB, valueB, k, n)
            _, grad_valueA = mm(indexC, grad_valueC, indexB, valueB, m, n, k)
            # TODO: Filter values.

        if ctx.needs_input_grad[3]:
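            # dB = A^T @ dC: transpose A to (k, m), then multiply it by dC (m, n).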
            indexA, valueA = transpose(indexA, valueA, m, k)
            _, grad_valueB = mm(indexA, valueA, indexC, grad_valueC, k, m, n)
            # TODO: Filter values.

        return None, grad_valueA, None, grad_valueB, None, None, None


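# Functional interface: spspmm(indexA, valueA, indexB, valueB, m, k, n) -> (indexC, valueC).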
spspmm = SpSpMM.apply


def mm(indexA, valueA, indexB, valueB, m, k, n):
    assert valueA.dtype == valueB.dtype

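    # On the GPU, delegate to the custom matmul_cuda kernel.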
    if indexA.is_cuda:
        return matmul_cuda.spspmm(indexA, valueA, indexB, valueB, m, k, n)

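    # CPU fallback: multiply in SciPy's CSR format and convert the result back to COO.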
    A = to_scipy(indexA, valueA, m, k)
    B = to_scipy(indexB, valueB, k, n)
    indexC, valueC = from_scipy(A.tocsr().dot(B.tocsr()).tocoo())

    return indexC, valueC


def to_scipy(index, value, m, n):
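    # Build a SciPy COO matrix; detach() lets SciPy view the CPU tensors as NumPy arrays.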
    (row, col), data = index.detach(), value.detach()
    return scipy.sparse.coo_matrix((data, (row, col)), (m, n))


def from_scipy(A):
    # A.row, A.col and A.data are NumPy arrays; convert them to torch tensors.
    row = torch.from_numpy(A.row).to(torch.long)
    col = torch.from_numpy(A.col).to(torch.long)
    value = torch.from_numpy(A.data)
    index = torch.stack([row, col], dim=0)
    return index, value
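

# ----------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original module).
# It exercises the CPU/SciPy code path of spspmm defined above; the shapes and
# values are made up and chosen so that the unfiltered gradients already match
# the sparsity patterns of valueA and valueB (see the TODOs in backward()).
if __name__ == '__main__':
    # A: sparse 2x3 matrix in COO form, B: sparse 3x2 matrix in COO form.
    indexA = torch.tensor([[0, 0, 1], [0, 2, 1]])
    valueA = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    indexB = torch.tensor([[0, 1, 2], [1, 0, 1]])
    valueB = torch.tensor([4.0, 5.0, 6.0], requires_grad=True)

    # C = A @ B has shape 2x2 with two non-zero entries.
    indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 2, 3, 2)
    valueC.sum().backward()
    print(indexC, valueC, valueA.grad, valueB.grad)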