matmul.py 3.92 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
from typing import Union, Tuple
rusty1s's avatar
matmul  
rusty1s committed
2

rusty1s's avatar
rusty1s committed
3
import torch
rusty1s's avatar
matmul  
rusty1s committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from torch_sparse.tensor import SparseTensor


@torch.jit.script
def spmm_sum(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
    """Call the `torch_sparse.spmm_sum` operator on `src` and `other`.

    The CSR triple of `src` is always forwarded. The `_row`, `_csr2csc`
    and `_colptr` entries of `src.storage` are forwarded in whatever state
    they currently are, unless a gradient is requested:

    * if the values of `src` or `other` require grad, `row` is obtained
      through `storage.row()`;
    * if `other` requires grad, `csr2csc` and `colptr` are additionally
      obtained through their accessor methods.

    NOTE(review): presumably these are the tensors the operator's backward
    pass consumes -- confirm against the operator implementation.
    """
    rowptr, col, value = src.csr()
    storage = src.storage

    # Default to the cached entries; they are only materialized on demand.
    row = storage._row
    csr2csc = storage._csr2csc
    colptr = storage._colptr

    if other.requires_grad:
        # Gradient w.r.t. the dense operand: request every auxiliary index.
        row = storage.row()
        csr2csc = storage.csr2csc()
        colptr = storage.colptr()
    elif value is not None and value.requires_grad:
        # Gradient w.r.t. the sparse values only: just the row index.
        row = storage.row()

    out = torch.ops.torch_sparse.spmm_sum(row, rowptr, col, value, colptr,
                                          csr2csc, other)
    return out


@torch.jit.script
def spmm_add(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
    """Alias of `spmm_sum`: forwards both arguments and returns its result."""
    out = spmm_sum(src, other)
    return out


32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
@torch.jit.script
def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
    """Call the `torch_sparse.spmm_mean` operator on `src` and `other`.

    The CSR triple of `src` is always forwarded. The `_row`, `_rowcount`,
    `_csr2csc` and `_colptr` entries of `src.storage` are forwarded in
    whatever state they currently are, unless a gradient is requested:

    * if the values of `src` or `other` require grad, `row` is obtained
      through `storage.row()`;
    * if `other` requires grad, `rowcount`, `csr2csc` and `colptr` are
      additionally obtained through their accessor methods.

    NOTE(review): presumably these are the tensors the operator's backward
    pass consumes -- confirm against the operator implementation.
    """
    rowptr, col, value = src.csr()
    storage = src.storage

    # Default to the cached entries; they are only materialized on demand.
    row = storage._row
    rowcount = storage._rowcount
    csr2csc = storage._csr2csc
    colptr = storage._colptr

    if other.requires_grad:
        # Gradient w.r.t. the dense operand: request every auxiliary index.
        row = storage.row()
        rowcount = storage.rowcount()
        csr2csc = storage.csr2csc()
        colptr = storage.colptr()
    elif value is not None and value.requires_grad:
        # Gradient w.r.t. the sparse values only: just the row index.
        row = storage.row()

    out = torch.ops.torch_sparse.spmm_mean(row, rowptr, col, value, rowcount,
                                           colptr, csr2csc, other)
    return out


rusty1s's avatar
rusty1s committed
54
55
56
57
58
59
60
61
62
63
64
65
66
67
@torch.jit.script
def spmm_min(src: SparseTensor,
             other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Call the `torch_sparse.spmm_min` operator on `src` and `other`.

    Only the CSR triple of `src` is passed to the operator, whose
    two-tensor result is returned unchanged.

    NOTE(review): the second tensor is presumably the arg-min index of each
    output entry -- confirm against the operator implementation.
    """
    indptr, indices, weight = src.csr()
    result = torch.ops.torch_sparse.spmm_min(indptr, indices, weight, other)
    return result


@torch.jit.script
def spmm_max(src: SparseTensor,
             other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Call the `torch_sparse.spmm_max` operator on `src` and `other`.

    Only the CSR triple of `src` is passed to the operator, whose
    two-tensor result is returned unchanged.

    NOTE(review): the second tensor is presumably the arg-max index of each
    output entry -- confirm against the operator implementation.
    """
    indptr, indices, weight = src.csr()
    result = torch.ops.torch_sparse.spmm_max(indptr, indices, weight, other)
    return result


rusty1s's avatar
matmul  
rusty1s committed
68
69
70
71
72
@torch.jit.script
def spmm(src: SparseTensor, other: torch.Tensor,
         reduce: str = "sum") -> torch.Tensor:
    """Dispatch a sparse-dense product to the kernel selected by `reduce`.

    Args:
        src: Sparse left operand.
        other: Dense right operand.
        reduce: One of `'sum'`, `'add'`, `'mean'`, `'min'` or `'max'`.

    Returns:
        The result of `spmm_sum` / `spmm_mean`, or the first element of
        the pair returned by `spmm_min` / `spmm_max`.

    Raises:
        ValueError: If `reduce` is none of the names listed above.
    """
    if reduce == 'sum' or reduce == 'add':
        return spmm_sum(src, other)

    if reduce == 'mean':
        return spmm_mean(src, other)

    if reduce == 'min':
        # Only the first tensor of the pair is returned to the caller.
        min_result = spmm_min(src, other)
        return min_result[0]

    if reduce == 'max':
        max_result = spmm_max(src, other)
        return max_result[0]

    raise ValueError


@torch.jit.script
def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
    """Multiply two sparse tensors with the `torch_sparse.spspmm_sum` op.

    Both operands are passed to the operator as CSR triples, together with
    `other.sparse_size(1)`. The returned `rowptr`/`col`/`value` are wrapped
    in a new `SparseTensor` of sparse size
    `[src.sparse_size(0), other.sparse_size(1)]`, constructed with
    `row=None` and `is_sorted=True`.

    Asserts that `src.sparse_size(1) == other.sparse_size(0)`.
    """
    # The inner dimensions of the two operands have to agree.
    assert src.sparse_size(1) == other.sparse_size(0)

    num_rows = src.sparse_size(0)
    num_cols = other.sparse_size(1)

    lhs_rowptr, lhs_col, lhs_value = src.csr()
    rhs_rowptr, rhs_col, rhs_value = other.csr()

    out_rowptr, out_col, out_value = torch.ops.torch_sparse.spspmm_sum(
        lhs_rowptr, lhs_col, lhs_value, rhs_rowptr, rhs_col, rhs_value,
        num_cols)

    out_sizes = torch.Size([num_rows, num_cols])
    return SparseTensor(row=None, rowptr=out_rowptr, col=out_col,
                        value=out_value, sparse_sizes=out_sizes,
                        is_sorted=True)
rusty1s's avatar
rusty1s committed
98
99
100
101
102
103
104
105
106
107
108
109
110
111


@torch.jit.script
def spspmm_add(src: SparseTensor, other: SparseTensor) -> SparseTensor:
    """Alias of `spspmm_sum`: forwards both arguments, returns its result."""
    out = spspmm_sum(src, other)
    return out


@torch.jit.script
def spspmm(src: SparseTensor, other: SparseTensor,
           reduce: str = "sum") -> SparseTensor:
    """Dispatch a sparse-sparse product according to `reduce`.

    Args:
        src: Sparse left operand.
        other: Sparse right operand.
        reduce: `'sum'` or `'add'`; `'mean'`, `'min'` and `'max'` are
            recognized but not implemented.

    Returns:
        The result of `spspmm_sum(src, other)`.

    Raises:
        NotImplementedError: If `reduce` is `'mean'`, `'min'` or `'max'`.
        ValueError: If `reduce` is any other string.
    """
    if reduce == 'sum' or reduce == 'add':
        return spspmm_sum(src, other)

    # Known reduction names that have no sparse-sparse implementation here.
    if reduce == 'mean' or reduce == 'min' or reduce == 'max':
        raise NotImplementedError

    raise ValueError


rusty1s's avatar
rusty1s committed
116
117
def matmul(src: SparseTensor,
           other: Union[torch.Tensor, SparseTensor],
           reduce: str = "sum"):
    """Multiply `src` with a dense or a sparse right-hand operand.

    Args:
        src: Sparse left operand.
        other: A `torch.Tensor` (handled by `spmm`) or a `SparseTensor`
            (handled by `spspmm`).
        reduce: Reduction name forwarded unchanged to `spmm` / `spspmm`.

    Returns:
        Whatever `spmm` or `spspmm` returns for the given operands.

    Raises:
        ValueError: If `other` is neither a `torch.Tensor` nor a
            `SparseTensor`.
    """
    # Dense operand is checked first, exactly as before.
    if isinstance(other, torch.Tensor):
        return spmm(src, other, reduce)

    if isinstance(other, SparseTensor):
        return spspmm(src, other, reduce)

    raise ValueError


# Expose the free functions above as `SparseTensor` methods.
#
# `reduce` defaults to "sum", the same default the free functions use.
# The previous default of `None` could never work: `spmm` and `spspmm` are
# scripted with `reduce: str` and only accept the reduction names they test
# for, so calling e.g. `x.spmm(y)` without an explicit `reduce` always
# failed.
SparseTensor.spmm = lambda self, other, reduce="sum": spmm(
    self, other, reduce)
SparseTensor.spspmm = lambda self, other, reduce="sum": spspmm(
    self, other, reduce)
SparseTensor.matmul = lambda self, other, reduce="sum": matmul(
    self, other, reduce)
# The `@` operator always uses the "sum" reduction.
SparseTensor.__matmul__ = lambda self, other: matmul(self, other, 'sum')