matmul.py 1.47 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
rusty1s's avatar
rusty1s committed
2
import scipy.sparse
rusty1s's avatar
rusty1s committed
3

rusty1s's avatar
rusty1s committed
4
5
6
7
if torch.cuda.is_available():
    import matmul_cuda


rusty1s's avatar
typo  
rusty1s committed
8
class SpSpMM(torch.autograd.Function):
    """Autograd function computing ``C = mm(A, B)`` with a custom backward.

    ``A`` and ``B`` are expected to be sparse tensors: the backward pass
    calls ``.t().coalesce()`` on them, which is only defined for sparse
    layouts.
    """

    @staticmethod
    def forward(ctx, A, B):
        """Return ``mm(A, B)`` and stash both operands for the backward pass."""
        ctx.save_for_backward(A, B)
        return mm(A, B)

    @staticmethod
    def backward(ctx, grad_C):
        """Return ``(grad_A, grad_B)`` for the incoming gradient ``grad_C``.

        Each gradient is only computed when ``ctx.needs_input_grad`` asks
        for it; otherwise ``None`` is returned in its place.
        """
        # `saved_tensors` is the supported accessor; `saved_variables` is the
        # deprecated spelling of the same thing.
        A, B = ctx.saved_tensors
        grad_A = grad_B = None

        if ctx.needs_input_grad[0]:
            # dL/dA = grad_C @ B^T
            grad_A = mm(grad_C, B.t().coalesce())

        if ctx.needs_input_grad[1]:
            # dL/dB = A^T @ grad_C; coalesce the transpose exactly as is done
            # for B above so both branches feed `mm` the same kind of input.
            grad_B = mm(A.t().coalesce(), grad_C)

        return grad_A, grad_B
rusty1s's avatar
rusty1s committed
26
27


rusty1s's avatar
rusty1s committed
28
# Functional alias: `spspmm(A, B)` is `SpSpMM.apply(A, B)`.
spspmm = SpSpMM.apply
rusty1s's avatar
rusty1s committed
29
30


rusty1s's avatar
rusty1s committed
31
32
33
34
def mm(A, B):
    """Multiply ``A`` by ``B``, picking the backend from ``A``'s device.

    Both operands must share a dtype and have compatible inner dimensions
    (``A.size(1) == B.size(0)``); either check failing raises
    ``AssertionError``. ``mm_cuda`` is used when ``A.is_cuda`` is true,
    ``mm_cpu`` otherwise.
    """
    assert A.dtype == B.dtype
    assert A.size(1) == B.size(0)
    backend = mm_cuda if A.is_cuda else mm_cpu
    return backend(A, B)
rusty1s's avatar
rusty1s committed
35

rusty1s's avatar
rusty1s committed
36

rusty1s's avatar
rusty1s committed
37
def mm_cuda(A, B):
    """Multiply ``A`` by ``B`` using the ``matmul_cuda`` extension.

    ``matmul_cuda`` is only imported at module level when
    ``torch.cuda.is_available()`` is true, so this function must only be
    reached on CUDA-enabled installs.

    Returns a sparse COO tensor of size ``(A.size(0), B.size(1))`` built from
    the ``(index, value)`` pair returned by ``matmul_cuda.spspmm`` and placed
    on the device of ``value``.
    """
    index, value = matmul_cuda.spspmm(A, B)
    size = torch.Size([A.size(0), B.size(1)])
    return torch.sparse_coo_tensor(index, value, size, device=value.device)
rusty1s's avatar
rusty1s committed
41
42


rusty1s's avatar
rusty1s committed
43
44
def mm_cpu(A, B):
    """Multiply ``A`` by ``B`` by round-tripping through SciPy.

    Both operands are converted with ``to_scipy``, multiplied with the SciPy
    matrix ``dot`` method, and the product is converted back with
    ``from_scipy``.
    """
    lhs = to_scipy(A)
    rhs = to_scipy(B)
    product = lhs.dot(rhs)
    return from_scipy(product)
rusty1s's avatar
rusty1s committed
45
46


rusty1s's avatar
rusty1s committed
47
48
49
50
def to_scipy(A):
    """Convert a 2-D sparse COO tensor ``A`` into a SciPy CSR matrix.

    The index and value tensors are detached from autograd and converted to
    NumPy arrays explicitly, rather than handing ``torch.Tensor`` objects to
    SciPy and relying on its implicit array coercion.

    Returns a ``scipy.sparse`` CSR matrix with the same shape as ``A``.
    """
    # `.detach()` is required because tensors that track gradients cannot be
    # exported to NumPy; `.cpu()` is a no-op for tensors already on the CPU.
    index = A._indices().detach().cpu().numpy()
    data = A._values().detach().cpu().numpy()
    row, col = index  # first axis of the index array holds (row, col)
    shape = tuple(A.size())
    return scipy.sparse.coo_matrix((data, (row, col)), shape).tocsr()
rusty1s's avatar
rusty1s committed
51
52


rusty1s's avatar
rusty1s committed
53
54
55
def from_scipy(A):
    """Convert a SciPy sparse matrix ``A`` into a sparse COO tensor.

    ``A`` is first converted to COO format. Its row and column arrays are
    cast to ``int64`` and stacked into a ``(2, nnz)`` index tensor; the data
    array becomes the value tensor with its dtype unchanged.

    Returns a ``torch.sparse_coo_tensor`` with the same shape as ``A``.
    """
    A = A.tocoo()
    row = torch.from_numpy(A.row).long()
    col = torch.from_numpy(A.col).long()
    value = torch.from_numpy(A.data)
    index = torch.stack([row, col], dim=0)
    size = torch.Size(A.shape)
    return torch.sparse_coo_tensor(index, value, size)