spmm.py 3.36 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
# import torch
rusty1s's avatar
rusty1s committed
2
3
from torch_scatter import scatter_add

rusty1s's avatar
rusty1s committed
4
# from torch_sparse.tensor import SparseTensor
rusty1s's avatar
rusty1s committed
5

rusty1s's avatar
rusty1s committed
6
7
# if torch.cuda.is_available():
#     import torch_sparse.spmm_cuda
rusty1s's avatar
rusty1s committed
8

rusty1s's avatar
rusty1s committed
9
10
11
12
# def spmm_(sparse_mat, mat, reduce='add'):
#     assert reduce in ['add', 'mean', 'min', 'max']
#     assert sparse_mat.dim() == 2 and mat.dim() == 2
#     assert sparse_mat.size(1) == mat.size(0)
rusty1s's avatar
rusty1s committed
13

rusty1s's avatar
rusty1s committed
14
15
#     rowptr, col, value = sparse_mat.csr()
#     mat = mat.contiguous()
rusty1s's avatar
rusty1s committed
16

rusty1s's avatar
rusty1s committed
17
18
19
#     if reduce in ['add', 'mean']:
#         return torch_sparse.spmm_cuda.spmm(rowptr, col, value, mat, reduce)
#     else:
rusty1s's avatar
rusty1s committed
20
21
#         return torch_sparse.spmm_cuda.spmm_arg(
# rowptr, col, value, mat, reduce)
rusty1s's avatar
rusty1s committed
22

rusty1s's avatar
rusty1s committed
23

rusty1s's avatar
rusty1s committed
24
def spmm(index, value, m, n, matrix):
    """Matrix product of sparse matrix with dense matrix.

    Computes ``A @ matrix`` where ``A`` is the ``[m, n]`` sparse matrix given
    in coordinate form by ``index`` and ``value``. Entries sharing the same
    row index are accumulated with a scatter-add, so duplicate coordinates
    are summed.

    Args:
        index (:class:`LongTensor`): The index tensor of sparse matrix; its
            first dimension has size 2, holding the row indices in
            ``index[0]`` and the column indices in ``index[1]``.
        value (:class:`Tensor`): The value tensor of sparse matrix, one value
            per ``(row, col)`` pair in ``index``.
        m (int): The first dimension of corresponding dense matrix, i.e. the
            number of rows of the sparse matrix (and of the result).
        n (int): The second dimension of corresponding dense matrix, i.e. the
            number of columns of the sparse matrix. Must equal the row
            dimension of ``matrix``.
        matrix (:class:`Tensor`): The dense matrix, of shape ``[n]``,
            ``[n, k]`` or batched ``[*, n, k]``.

    :rtype: :class:`Tensor` of shape ``[*, m, k]``. A 1-D ``matrix`` is
        treated as a single column, so the result has shape ``[m, 1]`` (the
        trailing dimension is not squeezed away).
    """
    row, col = index

    # Treat a vector as a one-column matrix so the row axis is always -2.
    matrix = matrix if matrix.dim() > 1 else matrix.unsqueeze(-1)

    assert n == matrix.size(-2), (
        f'Dense matrix has {matrix.size(-2)} rows, but the sparse matrix '
        f'has n={n} columns.')

    # Gather the dense rows addressed by each non-zero's column index along
    # the row axis (-2), which also works with leading batch dimensions.
    out = matrix.index_select(-2, col)
    # Scale each gathered row by its non-zero value; unsqueeze(-1) lets the
    # per-entry value broadcast across the k columns (and any batch dims).
    out = out * value.unsqueeze(-1)
    # Sum contributions that land in the same output row; dim_size=m keeps
    # rows without any non-zero entry (they stay zero).
    out = scatter_add(out, row, dim=-2, dim_size=m)

    return out
rusty1s's avatar
rusty1s committed
47
48


rusty1s's avatar
rusty1s committed
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# if __name__ == '__main__':
#     device = 'cuda' if torch.cuda.is_available() else 'cpu'
#     row = torch.tensor([0, 0, 0, 1, 1, 1], device=device)
#     col = torch.tensor([0, 1, 2, 0, 1, 2], device=device)
#     value = torch.ones_like(col, dtype=torch.float, device=device)
#     value = None
#     sparse_mat = SparseTensor(torch.stack([row, col], dim=0), value)
#     mat = torch.tensor([[1, 4], [2, 5], [3, 6]], dtype=torch.float,
#                        device=device)
#     out1 = spmm_(sparse_mat, mat, reduce='add')
#     out2 = sparse_mat.to_dense() @ mat
#     assert torch.allclose(out1, out2)

#     from torch_geometric.datasets import Reddit, Planetoid  # noqa
#     import time  # noqa

#     # Warmup
#     x = torch.randn((1000, 1000), device=device)
#     for _ in range(100):
#         x.sum()

#     # dataset = Reddit('/tmp/Reddit')
#     dataset = Planetoid('/tmp/PubMed', 'PubMed')
#     # dataset = Planetoid('/tmp/Cora', 'Cora')
#     data = dataset[0].to(device)
#     mat = torch.randn((data.num_nodes, 1024), device=device)
#     value = torch.ones(data.num_edges, device=device)

#     sparse_mat = SparseTensor(data.edge_index, value)
#     torch.cuda.synchronize()
#     t = time.perf_counter()
#     for _ in range(100):
#         out1 = spmm_(sparse_mat, mat, reduce='add')
#         out1 = out1[0] if isinstance(out1, tuple) else out1
#     torch.cuda.synchronize()
#     print('My:   ', time.perf_counter() - t)

#     sparse_mat = torch.sparse_coo_tensor(data.edge_index, value)
#     sparse_mat = sparse_mat.coalesce()

#     torch.cuda.synchronize()
#     t = time.perf_counter()
#     for _ in range(100):
#         out2 = sparse_mat @ mat
#     torch.cuda.synchronize()
#     print('Torch: ', time.perf_counter() - t)

#     torch.cuda.synchronize()
#     t = time.perf_counter()
#     for _ in range(100):
#         spmm(data.edge_index, value, data.num_nodes, data.num_nodes, mat)
#     torch.cuda.synchronize()
#     print('Scatter:', time.perf_counter() - t)

#     assert torch.allclose(out1, out2, atol=1e-2)