matmul.py 4.28 KB
Newer Older
Matthias Fey's avatar
Matthias Fey committed
1
from typing import Optional, Tuple
rusty1s's avatar
matmul  
rusty1s committed
2

rusty1s's avatar
rusty1s committed
3
import torch
Matthias Fey's avatar
Matthias Fey committed
4
from torch import Tensor
rusty1s's avatar
rusty1s committed
5

rusty1s's avatar
matmul  
rusty1s committed
6
7
8
9
10
11
12
13
14
15
from torch_sparse.tensor import SparseTensor


def spmm_sum(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
    """Sparse-dense product of ``src`` and ``other`` with sum reduction.

    Thin wrapper around the ``torch.ops.torch_sparse.spmm_sum`` custom op.
    The auxiliary index tensors (COO ``row``, ``colptr`` and the ``csr2csc``
    permutation) are first taken from the storage's cached attributes and
    are only explicitly computed when autograd is tracking the sparse
    values or ``other`` (presumably because the backward kernel needs
    them -- confirm against the op's autograd implementation).

    Args:
        src: Sparse left-hand operand.
        other: Dense right-hand operand. The sparse values, if present, are
            cast to ``other.dtype`` before the op is called.

    Returns:
        The tensor produced by ``torch.ops.torch_sparse.spmm_sum``.
    """
    rowptr, col, val = src.csr()
    if val is not None:
        # Hand the kernel values in the same dtype as the dense operand.
        val = val.to(other.dtype)

    storage = src.storage
    # Cached copies only; presumably ``None`` when not yet computed.
    row = storage._row
    csr2csc = storage._csr2csc
    colptr = storage._colptr

    # Autograd tracks the sparse values: force the COO row vector.
    if val is not None and val.requires_grad:
        row = storage.row()

    # Autograd tracks the dense operand: force every auxiliary index.
    if other.requires_grad:
        row = storage.row()
        csr2csc = storage.csr2csc()
        colptr = storage.colptr()

    return torch.ops.torch_sparse.spmm_sum(row, rowptr, col, val, colptr,
                                           csr2csc, other)


def spmm_add(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
    """Alias of ``spmm_sum``: 'add' and 'sum' name the same reduction."""
    out = spmm_sum(src, other)
    return out


35
36
37
38
39
40
41
42
def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
    """Sparse-dense product of ``src`` and ``other`` with mean reduction.

    Thin wrapper around the ``torch.ops.torch_sparse.spmm_mean`` custom op.
    Like ``spmm_sum``, the auxiliary tensors (``row``, ``rowcount``,
    ``colptr``, ``csr2csc``) come from the storage cache and are only
    explicitly computed when autograd tracks the sparse values or
    ``other``. ``rowcount`` is presumably the per-row number of stored
    entries used as the mean's divisor -- confirm against the kernel.

    Args:
        src: Sparse left-hand operand.
        other: Dense right-hand operand. The sparse values, if present, are
            cast to ``other.dtype`` before the op is called.

    Returns:
        The tensor produced by ``torch.ops.torch_sparse.spmm_mean``.
    """
    rowptr, col, val = src.csr()
    if val is not None:
        # Hand the kernel values in the same dtype as the dense operand.
        val = val.to(other.dtype)

    storage = src.storage
    # Cached copies only; presumably ``None`` when not yet computed.
    row = storage._row
    rowcount = storage._rowcount
    csr2csc = storage._csr2csc
    colptr = storage._colptr

    # Autograd tracks the sparse values: force the COO row vector.
    if val is not None and val.requires_grad:
        row = storage.row()

    # Autograd tracks the dense operand: force every auxiliary tensor.
    if other.requires_grad:
        row = storage.row()
        rowcount = storage.rowcount()
        csr2csc = storage.csr2csc()
        colptr = storage.colptr()

    return torch.ops.torch_sparse.spmm_mean(row, rowptr, col, val, rowcount,
                                            colptr, csr2csc, other)


rusty1s's avatar
rusty1s committed
59
60
61
def spmm_min(src: SparseTensor,
             other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sparse-dense product of ``src`` and ``other`` with min reduction.

    Calls ``torch.ops.torch_sparse.spmm_min`` on the CSR representation of
    ``src``. The first element of the returned pair is the reduced output
    (this is what ``spmm`` uses). The second element is presumably the
    arg-min index tensor -- confirm against the kernel.
    """
    rowptr, col, val = src.csr()
    if val is not None:
        # Hand the kernel values in the same dtype as the dense operand.
        val = val.to(other.dtype)
    return torch.ops.torch_sparse.spmm_min(rowptr, col, val, other)


def spmm_max(src: SparseTensor,
             other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sparse-dense product of ``src`` and ``other`` with max reduction.

    Calls ``torch.ops.torch_sparse.spmm_max`` on the CSR representation of
    ``src``. The first element of the returned pair is the reduced output
    (this is what ``spmm`` uses). The second element is presumably the
    arg-max index tensor -- confirm against the kernel.
    """
    rowptr, col, val = src.csr()
    if val is not None:
        # Hand the kernel values in the same dtype as the dense operand.
        val = val.to(other.dtype)
    return torch.ops.torch_sparse.spmm_max(rowptr, col, val, other)


rusty1s's avatar
matmul  
rusty1s committed
79
80
81
82
def spmm(src: SparseTensor, other: torch.Tensor,
         reduce: str = "sum") -> torch.Tensor:
    """Sparse-dense matrix product ``src @ other`` with a row reduction.

    Args:
        src: Sparse left-hand operand.
        other: Dense right-hand operand.
        reduce: One of ``'sum'``/``'add'`` (synonyms), ``'mean'``,
            ``'min'`` or ``'max'``.

    Returns:
        The reduced product. For ``'min'``/``'max'`` only the first element
        of the underlying op's ``(output, extra)`` pair is returned.

    Raises:
        ValueError: If ``reduce`` is not one of the names listed above.
    """
    if reduce == 'sum' or reduce == 'add':
        return spmm_sum(src, other)
    elif reduce == 'mean':
        return spmm_mean(src, other)
    elif reduce == 'min':
        return spmm_min(src, other)[0]
    elif reduce == 'max':
        return spmm_max(src, other)[0]
    else:
        # Plain string concatenation keeps the message TorchScript-friendly.
        raise ValueError("Unsupported reduction '" + reduce + "' for spmm "
                         "(expected 'sum', 'add', 'mean', 'min' or 'max')")


def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
    """Sparse-sparse matrix product ``src @ other`` with sum reduction.

    Both operands are converted to ``torch.sparse_coo_tensor`` and
    multiplied with ``torch.sparse.mm``. The result carries explicit values
    whenever at least one operand has values. It is returned value-less
    (pattern only) only when neither operand has values.

    Args:
        src: Sparse left-hand operand.
        other: Sparse right-hand operand.

    Returns:
        A ``SparseTensor`` of size ``(C.size(0), C.size(1))`` built from the
        coalesced product ``C``.
    """
    A = src.to_torch_sparse_coo_tensor()
    B = other.to_torch_sparse_coo_tensor()
    # ``is_sorted=True`` below requires sorted, duplicate-free indices;
    # ``coalesce()`` guarantees that and is a no-op if ``C`` already is.
    C = torch.sparse.mm(A, B).coalesce()
    edge_index = C._indices()
    row, col = edge_index[0], edge_index[1]
    value: Optional[Tensor] = None
    # A value-less operand still contributes stored values to ``A``/``B``
    # (presumably all ones -- confirm in ``to_torch_sparse_coo_tensor``).
    # The product is therefore weighted as soon as either side is.
    if src.has_value() or other.has_value():
        value = C._values()

    return SparseTensor(
        row=row,
        col=col,
        value=value,
        sparse_sizes=(C.size(0), C.size(1)),
        is_sorted=True,
        trust_data=True,
    )
rusty1s's avatar
rusty1s committed
111
112
113
114
115
116
117
118
119
120
121
122


def spspmm_add(src: SparseTensor, other: SparseTensor) -> SparseTensor:
    """Alias of ``spspmm_sum``: 'add' and 'sum' name the same reduction."""
    out = spspmm_sum(src, other)
    return out


def spspmm(src: SparseTensor, other: SparseTensor,
           reduce: str = "sum") -> SparseTensor:
    """Sparse-sparse matrix product ``src @ other`` with a reduction.

    Args:
        src: Sparse left-hand operand.
        other: Sparse right-hand operand.
        reduce: ``'sum'``/``'add'`` (synonyms) are implemented. ``'mean'``,
            ``'min'`` and ``'max'`` are recognised but not implemented.

    Returns:
        The product as a ``SparseTensor``.

    Raises:
        NotImplementedError: For ``'mean'``, ``'min'`` or ``'max'``.
        ValueError: For any other ``reduce`` value.
    """
    if reduce == 'sum' or reduce == 'add':
        return spspmm_sum(src, other)
    elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
        raise NotImplementedError("spspmm does not support reduce='" +
                                  reduce + "' yet (only 'sum'/'add')")
    else:
        # Plain string concatenation keeps the message TorchScript-friendly.
        raise ValueError("Unsupported reduction '" + reduce + "' for spspmm "
                         "(expected 'sum', 'add', 'mean', 'min' or 'max')")


rusty1s's avatar
rusty1s committed
127
@torch.jit._overload  # noqa: F811
def matmul(src, other, reduce):  # noqa: F811
    # type: (SparseTensor, torch.Tensor, str) -> torch.Tensor
    pass


@torch.jit._overload  # noqa: F811
def matmul(src, other, reduce):  # noqa: F811
    # type: (SparseTensor, SparseTensor, str) -> SparseTensor
    pass


def matmul(src, other, reduce="sum"):  # noqa: F811
    """Multiply sparse ``src`` by ``other`` with the given reduction.

    The two ``torch.jit._overload`` stubs above declare the TorchScript
    signatures: a dense ``torch.Tensor`` operand yields a dense result via
    ``spmm``, and a ``SparseTensor`` operand yields a ``SparseTensor`` via
    ``spspmm``.

    Raises:
        ValueError: If ``other`` is neither a ``torch.Tensor`` nor a
            ``SparseTensor`` (``spmm``/``spspmm`` may raise for ``reduce``).
    """
    if isinstance(other, torch.Tensor):
        return spmm(src, other, reduce)
    elif isinstance(other, SparseTensor):
        return spspmm(src, other, reduce)
    raise ValueError("matmul expects `other` to be a torch.Tensor or a "
                     "SparseTensor")
rusty1s's avatar
matmul  
rusty1s committed
145
146


rusty1s's avatar
rusty1s committed
147
148
# Expose the functional API as ``SparseTensor`` methods. The ``@`` operator
# uses the default 'sum' reduction.
SparseTensor.spmm = lambda self, other, reduce="sum": spmm(
    self, other, reduce=reduce)
SparseTensor.spspmm = lambda self, other, reduce="sum": spspmm(
    self, other, reduce=reduce)
SparseTensor.matmul = lambda self, other, reduce="sum": matmul(
    self, other, reduce=reduce)
SparseTensor.__matmul__ = lambda self, other: matmul(
    self, other, reduce='sum')