"example/37_permute/permute_1xHxW_fp16.cpp" did not exist on "e294c6756b9d9f6aa5038a3e8c42bbc61e4c97bd"
spmm.py 3.35 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
rusty1s's avatar
rusty1s committed
2
3
from torch_scatter import scatter_add

rusty1s's avatar
rusty1s committed
4
# from torch_sparse.tensor import SparseTensor
rusty1s's avatar
rusty1s committed
5

rusty1s's avatar
rusty1s committed
6
7
# if torch.cuda.is_available():
#     import torch_sparse.spmm_cuda
rusty1s's avatar
rusty1s committed
8

rusty1s's avatar
rusty1s committed
9
10
11
12
# def spmm_(sparse_mat, mat, reduce='add'):
#     assert reduce in ['add', 'mean', 'min', 'max']
#     assert sparse_mat.dim() == 2 and mat.dim() == 2
#     assert sparse_mat.size(1) == mat.size(0)
rusty1s's avatar
rusty1s committed
13

rusty1s's avatar
rusty1s committed
14
15
#     rowptr, col, value = sparse_mat.csr()
#     mat = mat.contiguous()
rusty1s's avatar
rusty1s committed
16

rusty1s's avatar
rusty1s committed
17
18
19
20
#     if reduce in ['add', 'mean']:
#         return torch_sparse.spmm_cuda.spmm(rowptr, col, value, mat, reduce)
#     else:
#         return torch_sparse.spmm_cuda.spmm_arg(rowptr, col, value, mat, reduce)
rusty1s's avatar
rusty1s committed
21

rusty1s's avatar
rusty1s committed
22

rusty1s's avatar
rusty1s committed
23
def spmm(index, value, m, n, matrix):
    """Matrix product of a sparse matrix with a dense matrix.

    The sparse matrix is described by its non-zero entries: ``index``
    unpacks into the row and column indices of the entries and ``value``
    holds the corresponding values.

    Args:
        index (:class:`LongTensor`): The index tensor of the sparse matrix.
            It is unpacked as ``row, col = index``.
        value (:class:`Tensor`): The value tensor of the sparse matrix.
        m (int): The first dimension of the sparse matrix. It is passed to
            ``scatter_add`` as ``dim_size``, i.e. the size of the first
            dimension of the result.
        n (int): The second dimension of the sparse matrix. Must be equal
            to ``matrix.size(0)``.
        matrix (:class:`Tensor`): The dense matrix. A one-dimensional
            tensor is treated as a matrix with a single column.

    :rtype: :class:`Tensor`
    """

    # The inner dimensions of the product have to agree.
    assert n == matrix.size(0), (
        'Second dimension of the sparse matrix ({}) does not match the '
        'first dimension of the dense matrix ({})'.format(n, matrix.size(0)))

    row, col = index
    # Promote a vector to a single-column matrix so that the per-entry
    # values below broadcast over a trailing feature dimension.
    matrix = matrix if matrix.dim() > 1 else matrix.unsqueeze(-1)

    # For every non-zero entry, pick the dense row addressed by its column
    # index and scale it by the value of the entry.
    out = matrix[col]
    out = out * value.unsqueeze(-1)
    # Sum up the scaled rows that belong to the same sparse row index.
    out = scatter_add(out, row, dim=0, dim_size=m)

    return out
rusty1s's avatar
rusty1s committed
46
47


rusty1s's avatar
rusty1s committed
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# if __name__ == '__main__':
#     device = 'cuda' if torch.cuda.is_available() else 'cpu'
#     row = torch.tensor([0, 0, 0, 1, 1, 1], device=device)
#     col = torch.tensor([0, 1, 2, 0, 1, 2], device=device)
#     value = torch.ones_like(col, dtype=torch.float, device=device)
#     value = None
#     sparse_mat = SparseTensor(torch.stack([row, col], dim=0), value)
#     mat = torch.tensor([[1, 4], [2, 5], [3, 6]], dtype=torch.float,
#                        device=device)
#     out1 = spmm_(sparse_mat, mat, reduce='add')
#     out2 = sparse_mat.to_dense() @ mat
#     assert torch.allclose(out1, out2)

#     from torch_geometric.datasets import Reddit, Planetoid  # noqa
#     import time  # noqa

#     # Warmup
#     x = torch.randn((1000, 1000), device=device)
#     for _ in range(100):
#         x.sum()

#     # dataset = Reddit('/tmp/Reddit')
#     dataset = Planetoid('/tmp/PubMed', 'PubMed')
#     # dataset = Planetoid('/tmp/Cora', 'Cora')
#     data = dataset[0].to(device)
#     mat = torch.randn((data.num_nodes, 1024), device=device)
#     value = torch.ones(data.num_edges, device=device)

#     sparse_mat = SparseTensor(data.edge_index, value)
#     torch.cuda.synchronize()
#     t = time.perf_counter()
#     for _ in range(100):
#         out1 = spmm_(sparse_mat, mat, reduce='add')
#         out1 = out1[0] if isinstance(out1, tuple) else out1
#     torch.cuda.synchronize()
#     print('My:   ', time.perf_counter() - t)

#     sparse_mat = torch.sparse_coo_tensor(data.edge_index, value)
#     sparse_mat = sparse_mat.coalesce()

#     torch.cuda.synchronize()
#     t = time.perf_counter()
#     for _ in range(100):
#         out2 = sparse_mat @ mat
#     torch.cuda.synchronize()
#     print('Torch: ', time.perf_counter() - t)

#     torch.cuda.synchronize()
#     t = time.perf_counter()
#     for _ in range(100):
#         spmm(data.edge_index, value, data.num_nodes, data.num_nodes, mat)
#     torch.cuda.synchronize()
#     print('Scatter:', time.perf_counter() - t)

#     assert torch.allclose(out1, out2, atol=1e-2)