spspmm.py 1.41 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
rusty1s's avatar
rusty1s committed
2
3
from torch_sparse.tensor import SparseTensor
from torch_sparse.matmul import matmul
rusty1s's avatar
rusty1s committed
4
5


6
def spspmm(indexA, valueA, indexB, valueB, m, k, n, coalesced=False):
    """Computes the sparse-sparse matrix product :obj:`A @ B`, where both
    operands are given as an index tensor plus a value tensor.

    Both inputs are expected to be coalesced already. Pass
    :obj:`coalesced=True` to have them handed to :class:`SparseTensor` as
    *not* sorted, so that it does not rely on the incoming entry order.

    Args:
        indexA (:class:`LongTensor`): Index tensor of the first sparse
            matrix; :obj:`indexA[0]` is used as the row indices and
            :obj:`indexA[1]` as the column indices.
        valueA (:class:`Tensor`): Value tensor of the first sparse matrix.
        indexB (:class:`LongTensor`): Index tensor of the second sparse
            matrix, laid out like :obj:`indexA`.
        valueB (:class:`Tensor`): Value tensor of the second sparse matrix.
        m (int): Number of rows of the first sparse matrix.
        k (int): Number of columns of the first sparse matrix, which is
            also the number of rows of the second sparse matrix.
        n (int): Number of columns of the second sparse matrix.
        coalesced (bool, optional): If set to :obj:`True`, both inputs are
            treated as unsorted when wrapped. (default: :obj:`False`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`) -- the row and column
        indices of the product stacked along dimension 0, and its values.
    """
    # The inputs are only declared pre-sorted when the caller did not ask
    # for coalescing.
    inputs_sorted = not coalesced

    lhs = SparseTensor(
        row=indexA[0],
        col=indexA[1],
        value=valueA,
        sparse_sizes=(m, k),
        is_sorted=inputs_sorted,
    )
    rhs = SparseTensor(
        row=indexB[0],
        col=indexB[1],
        value=valueB,
        sparse_sizes=(k, n),
        is_sorted=inputs_sorted,
    )

    # Convert the product back to the (index, value) convention used by the
    # arguments.
    out_row, out_col, out_value = matmul(lhs, rhs).coo()
    out_index = torch.stack([out_row, out_col], dim=0)
    return out_index, out_value