matmul.py 1.55 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
rusty1s's avatar
rusty1s committed
2
import scipy.sparse
rusty1s's avatar
rusty1s committed
3

rusty1s's avatar
rusty1s committed
4
5
6
7
if torch.cuda.is_available():
    import matmul_cuda


rusty1s's avatar
typo  
rusty1s committed
8
class SpSpMM(torch.autograd.Function):
    """Sparse matrix product of two sparse tensors with autograd support."""

    @staticmethod
    def forward(ctx, A, B):
        """Compute ``C = A @ B`` and stash both operands for ``backward``.

        Args:
            ctx: Autograd context used to save tensors for the backward pass.
            A: Left sparse operand.
            B: Right sparse operand.

        Returns:
            The result of ``mm(A, B)``.
        """
        ctx.save_for_backward(A, B)
        return mm(A, B)

    @staticmethod
    def backward(ctx, grad_C):
        """Propagate ``grad_C`` back to the operands of ``forward``.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_C: Gradient with respect to the output of ``forward``.

        Returns:
            A ``(grad_A, grad_B)`` tuple where ``grad_A = grad_C @ B^T`` and
            ``grad_B = A^T @ grad_C``. An entry is ``None`` when the
            corresponding input does not require a gradient.
        """
        # `saved_tensors` is the supported accessor; `saved_variables` is
        # deprecated in PyTorch.
        A, B = ctx.saved_tensors
        grad_A = grad_B = None

        # A transposed sparse tensor may be left uncoalesced, so both
        # transposes are coalesced before being handed to `mm`; the two
        # branches then treat their operands the same way.
        if ctx.needs_input_grad[0]:
            grad_A = mm(grad_C, B.t().coalesce())

        if ctx.needs_input_grad[1]:
            grad_B = mm(A.t().coalesce(), grad_C)

        return grad_A, grad_B
rusty1s's avatar
rusty1s committed
28
29


rusty1s's avatar
rusty1s committed
30
# Functional alias: `spspmm(A, B)` invokes `SpSpMM` through `Function.apply`.
spspmm = SpSpMM.apply
rusty1s's avatar
rusty1s committed
31
32


rusty1s's avatar
rusty1s committed
33
34
35
36
def mm(A, B):
    """Multiply two sparse tensors, dispatching on the device of ``A``.

    Both operands must share a dtype and have compatible inner dimensions
    (``A.size(1) == B.size(0)``); otherwise an ``AssertionError`` is raised.
    CUDA inputs go to ``mm_cuda``, everything else to ``mm_cpu``.
    """
    assert B.dtype == A.dtype
    assert B.size(0) == A.size(1)

    if A.is_cuda:
        return mm_cuda(A, B)
    return mm_cpu(A, B)
rusty1s's avatar
rusty1s committed
37

rusty1s's avatar
rusty1s committed
38

rusty1s's avatar
rusty1s committed
39
def mm_cuda(A, B):
    """Multiply two sparse tensors via the ``matmul_cuda`` extension.

    The extension returns the index and value tensors of the product; they
    are wrapped into a sparse COO tensor of shape
    ``(A.size(0), B.size(1))`` on the device of the returned values.
    """
    out_index, out_value = matmul_cuda.spspmm(A, B)
    out_shape = torch.Size((A.size(0), B.size(1)))
    return torch.sparse_coo_tensor(
        out_index, out_value, out_shape, device=out_value.device)
rusty1s's avatar
rusty1s committed
43
44


rusty1s's avatar
rusty1s committed
45
46
def mm_cpu(A, B):
    """Multiply two sparse tensors by round-tripping through SciPy.

    Each operand is converted with ``to_scipy``, the product is taken with
    the SciPy matrix ``dot`` method, and the result is converted back with
    ``from_scipy``.
    """
    lhs = to_scipy(A)
    rhs = to_scipy(B)
    product = lhs.dot(rhs)
    return from_scipy(product)
rusty1s's avatar
rusty1s committed
47
48


rusty1s's avatar
rusty1s committed
49
50
51
52
def to_scipy(A):
    """Convert a sparse tensor with two index rows into a SciPy CSR matrix.

    The index and value tensors are read with ``_indices()`` / ``_values()``
    and detached from the autograd graph before being handed to SciPy.
    """
    indices = A._indices()
    values = A._values()
    dims = tuple(A.size())

    # Unpacking requires exactly two index rows (row ids, column ids).
    rows, cols = indices
    rows = rows.detach()
    cols = cols.detach()
    values = values.detach()

    coo = scipy.sparse.coo_matrix((values, (rows, cols)), dims)
    return coo.tocsr()
rusty1s's avatar
rusty1s committed
53
54


rusty1s's avatar
rusty1s committed
55
56
57
def from_scipy(A):
    """Convert a SciPy sparse matrix into a torch sparse COO tensor.

    The matrix is first brought into COO layout; its row/column arrays are
    cast to ``long`` and stacked into a ``2 x nnz`` index tensor, while the
    data array becomes the value tensor. The result keeps the matrix shape.
    """
    coo = A.tocoo()
    index = torch.stack(
        [torch.from_numpy(coo.row).long(), torch.from_numpy(coo.col).long()],
        dim=0,
    )
    value = torch.from_numpy(coo.data)
    return torch.sparse_coo_tensor(index, value, torch.Size(coo.shape))