spmm.py

import torch
from torch_scatter import scatter_add

from torch_sparse.sparse import SparseTensor

# The custom kernels are only imported when CUDA is available; `spmm_` below
# relies on them.
if torch.cuda.is_available():
    import torch_sparse.spmm_cuda


def spmm_(sparse_mat, mat, reduce='add'):
    """Multiplies a :class:`SparseTensor` with a dense matrix using the custom
    CUDA kernels, reducing entries per row via :obj:`reduce`."""
    assert reduce in ['add', 'mean', 'min', 'max']
    assert sparse_mat.dim() == 2 and mat.dim() == 2
    assert sparse_mat.size(1) == mat.size(0)

    # The kernels operate on the CSR representation and need contiguous input.
    rowptr, col, value = sparse_mat.csr()
    mat = mat.contiguous()

    if reduce in ['add', 'mean']:
        return torch_sparse.spmm_cuda.spmm(rowptr, col, value, mat, reduce)
    else:
        # 'min'/'max' use the kernel variant that also tracks argument indices.
        return torch_sparse.spmm_cuda.spmm_arg(rowptr, col, value, mat, reduce)
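
# A minimal usage sketch for `spmm_` (names as defined in this file; assumes
# the CUDA extension is built). For reduce='min'/'max' the benchmark below
# unpacks a tuple, so that path presumably also returns the arg-indices:
#
#   A = SparseTensor(torch.stack([row, col], dim=0), value)  # sparse [m, n]
#   out = spmm_(A, mat, reduce='add')                         # dense [m, k]
#   out = spmm_(A, mat, reduce='max')
#   out = out[0] if isinstance(out, tuple) else out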


def spmm(index, value, m, n, matrix):
    """Matrix product of sparse matrix with dense matrix.

    Args:
        index (:class:`LongTensor`): The index tensor of sparse matrix.
        value (:class:`Tensor`): The value tensor of sparse matrix.
        m (int): The first dimension of the sparse matrix.
        n (int): The second dimension of the sparse matrix, which must match
            :obj:`matrix.size(0)`.
        matrix (:class:`Tensor`): The dense matrix.

    :rtype: :class:`Tensor`
    """

    assert n == matrix.size(0)

    row, col = index
    # Treat a dense vector as a one-column matrix.
    matrix = matrix if matrix.dim() > 1 else matrix.unsqueeze(-1)

    # Gather the dense rows selected by the column indices, weight them by the
    # non-zero values, and sum the contributions of each sparse row.
    out = matrix[col]
    out = out * value.unsqueeze(-1)
    out = scatter_add(out, row, dim=0, dim_size=m)

    return out
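
# A tiny worked example for `spmm` (a sketch; the values are chosen here for
# illustration): the 2 x 3 sparse matrix [[1, 0, 2], [0, 3, 0]] times a dense
# 3 x 2 matrix.
#
#   index = torch.tensor([[0, 0, 1], [0, 2, 1]])  # rows / cols of non-zeros
#   value = torch.tensor([1., 2., 3.])
#   dense = torch.randn(3, 2)
#   out = spmm(index, value, 2, 3, dense)         # -> shape [2, 2]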


if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
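    # Correctness check: the custom CUDA kernel against a dense matmul on a
    # small hand-built sparse matrix. (The `spmm_` call and the
    # `torch.cuda.synchronize()` calls below effectively require CUDA.)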
    row = torch.tensor([0, 0, 0, 1, 1, 1], device=device)
    col = torch.tensor([0, 1, 2, 0, 1, 2], device=device)
    # Use an implicit (all-one) value tensor.
    value = None
    sparse_mat = SparseTensor(torch.stack([row, col], dim=0), value)
    mat = torch.tensor([[1, 4], [2, 5], [3, 6]], dtype=torch.float,
                       device=device)
    out1 = spmm_(sparse_mat, mat, reduce='add')
    out2 = sparse_mat.to_dense() @ mat
    assert torch.allclose(out1, out2)

    from torch_geometric.datasets import Reddit, Planetoid  # noqa
    import time  # noqa
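
    # Benchmark: compare the custom kernel ('My'), torch.sparse ('Torch') and
    # the scatter_add-based `spmm` ('Scatter') on a Planetoid citation graph.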

    # Warmup
    x = torch.randn((1000, 1000), device=device)
    for _ in range(100):
        x.sum()

    # dataset = Reddit('/tmp/Reddit')
    dataset = Planetoid('/tmp/PubMed', 'PubMed')
    # dataset = Planetoid('/tmp/Cora', 'Cora')
    data = dataset[0].to(device)
    mat = torch.randn((data.num_nodes, 1024), device=device)
    value = torch.ones(data.num_edges, device=device)

    # (1) Custom CUDA kernel, fed the CSR representation via SparseTensor.
    sparse_mat = SparseTensor(data.edge_index, value)
    torch.cuda.synchronize()
    t = time.perf_counter()
    for _ in range(100):
        out1 = spmm_(sparse_mat, mat, reduce='add')
        out1 = out1[0] if isinstance(out1, tuple) else out1
    torch.cuda.synchronize()
    print('My:     ', time.perf_counter() - t)

    # (2) PyTorch's built-in sparse-dense matmul on a coalesced COO tensor.
    sparse_mat = torch.sparse_coo_tensor(data.edge_index, value)
    sparse_mat = sparse_mat.coalesce()

    torch.cuda.synchronize()
    t = time.perf_counter()
    for _ in range(100):
        out2 = sparse_mat @ mat
    torch.cuda.synchronize()
    print('Torch:  ', time.perf_counter() - t)

    # (3) Pure scatter_add-based implementation.
    torch.cuda.synchronize()
    t = time.perf_counter()
    for _ in range(100):
        spmm(data.edge_index, value, data.num_nodes, data.num_nodes, mat)
    torch.cuda.synchronize()
    print('Scatter:', time.perf_counter() - t)

    assert torch.allclose(out1, out2, atol=1e-2)