matmul.py
import torch
from torch import from_numpy
from scipy.sparse import coo_matrix


class SparseSparseMatmul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, matrix1, matrix2):
        ctx.save_for_backward(matrix1, matrix2)
        return matmul(matrix1, matrix2)

    @staticmethod
    def backward(ctx, grad_out):
        matrix1, matrix2 = ctx.saved_tensors
        grad_matrix1 = grad_matrix2 = None

        # For C = A @ B: dL/dA = dL/dC @ B^T and dL/dB = A^T @ dL/dC.
        if ctx.needs_input_grad[0]:
            grad_matrix1 = matmul(grad_out, matrix2.t())

        if ctx.needs_input_grad[1]:
            grad_matrix2 = matmul(matrix1.t(), grad_out)

        return grad_matrix1, grad_matrix2


sparse_sparse_matmul = SparseSparseMatmul.apply


def matmul(A, B):
    if A[0].is_cuda:
        # A CUDA sparse-sparse matmul is not implemented yet; fail loudly
        # instead of silently returning None.
        raise NotImplementedError('Sparse-sparse matmul is only supported on CPU')
    else:
        return matmul_cpu(A, B)


def matmul_cpu(A, B):
    # Convert both operands to SciPy CSR, multiply on the CPU, and convert
    # the product back to COO index/value tensors.
    A, B = to_csr(A), to_csr(B)
    C = A.dot(B).tocoo()
    row, col, value = from_numpy(C.row), from_numpy(C.col), from_numpy(C.data)
    return torch.stack([row, col], dim=0), value


def to_csr(A):
    # A is an (index, value, size) triple; unpack the index into row and
    # column coordinates and build a SciPy CSR matrix from them.
    (row, col), value, size = A
    row, col, value = row.numpy(), col.numpy(), value.numpy()
    return coo_matrix((value, (row, col)), shape=(size[0], size[1])).tocsr()
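

# A minimal usage sketch (my assumption about the input format, inferred from
# to_csr and matmul_cpu): each sparse matrix is an (index, value, size) triple
# where `index` is a [2, nnz] LongTensor of COO coordinates, `value` holds the
# nonzero entries, and `size` is the dense shape. This exercises only the CPU
# path via matmul().
if __name__ == '__main__':
    # A = [[1, 0, 2],
    #      [0, 3, 0]] with nonzeros at (0, 0), (0, 2) and (1, 1).
    index_a = torch.tensor([[0, 0, 1], [0, 2, 1]])
    value_a = torch.tensor([1.0, 2.0, 3.0])

    # B = [[0, 4],
    #      [5, 0],
    #      [0, 6]] with nonzeros at (0, 1), (1, 0) and (2, 1).
    index_b = torch.tensor([[0, 1, 2], [1, 0, 1]])
    value_b = torch.tensor([4.0, 5.0, 6.0])

    A = (index_a, value_a, (2, 3))
    B = (index_b, value_b, (3, 2))

    # Expected product: [[0, 16], [15, 0]].
    index_c, value_c = matmul(A, B)
    print(index_c)  # COO indices of the nonzero entries of A @ B
    print(value_c)  # corresponding values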