add.py 3.5 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
from typing import Optional

rusty1s's avatar
rusty1s committed
3
import torch
4
from torch import Tensor
rusty1s's avatar
rusty1s committed
5
from torch_scatter import gather_csr
rusty1s's avatar
rusty1s committed
6
from torch_sparse.tensor import SparseTensor
rusty1s's avatar
rusty1s committed
7

rusty1s's avatar
rusty1s committed
8

9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# TorchScript overload: SparseTensor + dense Tensor. The ``# type:`` comment
# declares this overload's signature; the real implementation is the
# undecorated ``add`` that follows.
@torch.jit._overload  # noqa: F811
def add(src, other):  # noqa: F811
    # type: (SparseTensor, Tensor) -> SparseTensor
    pass


# TorchScript overload: SparseTensor + SparseTensor.
@torch.jit._overload  # noqa: F811
def add(src, other):  # noqa: F811
    # type: (SparseTensor, SparseTensor) -> SparseTensor
    pass


def add(src, other):  # noqa: F811
    """Add a dense tensor or another :class:`SparseTensor` to ``src``.

    Dense ``other``: must have shape ``(src.size(0), 1, ...)`` (row-wise) or
    ``(1, src.size(1), ...)`` (column-wise). It is broadcast onto the stored
    non-zeros of ``src`` and added to their values; the sparsity pattern of
    ``src`` is kept.

    Sparse ``other``: the result holds the union of both sparsity patterns,
    sized to the per-dimension maximum of the two operands, and entries at
    the same position are combined by ``coalesce(reduce='sum')``.

    A pattern-only operand (no stored values) is treated as all ones. If
    *neither* sparse operand stores values, the result is pattern-only too,
    so overlapping positions are presumably merged rather than counted
    twice -- TODO confirm this is intended.

    Raises:
        ValueError: if a dense ``other`` has fewer than two dimensions or an
            incompatible shape.
        NotImplementedError: if ``other`` is neither a ``Tensor`` nor a
            ``SparseTensor``.
    """
    if isinstance(other, Tensor):
        rowptr, col, value = src.csr()
        # Check the rank first so 0-D/1-D input raises the ValueError below
        # rather than an IndexError from ``other.size(1)``.
        is_2d = other.dim() >= 2
        if is_2d and other.size(0) == src.size(0) and other.size(1) == 1:
            # Row-wise: expand per-row entries to one per stored non-zero.
            other = gather_csr(other.squeeze(1), rowptr)
        elif is_2d and other.size(0) == 1 and other.size(1) == src.size(1):
            # Column-wise: pick the entry for each non-zero's column.
            other = other.squeeze(0)[col]
        else:
            raise ValueError(
                f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or '
                f'(1, {src.size(1)}, ...), but got size {other.size()}.')
        # ``other`` is a freshly allocated tensor at this point (gather /
        # advanced indexing), so the in-place ops below never touch the
        # caller's input.
        if value is not None:
            value = other.to(value.dtype).add_(value)
        else:
            value = other.add_(1)
        return src.set_value(value, layout='coo')

    elif isinstance(other, SparseTensor):
        rowA, colA, valueA = src.coo()
        rowB, colB, valueB = other.coo()

        row = torch.cat([rowA, rowB], dim=0)
        col = torch.cat([colA, colB], dim=0)

        # Fill in ones for a pattern-only operand so the explicit values of
        # the other operand are not silently dropped.
        value: Optional[Tensor] = None
        if valueA is not None and valueB is not None:
            value = torch.cat([valueA, valueB], dim=0)
        elif valueA is not None:
            onesB = _implicit_ones(valueA, rowB.numel())
            value = torch.cat([valueA, onesB], dim=0)
        elif valueB is not None:
            onesA = _implicit_ones(valueB, rowA.numel())
            value = torch.cat([onesA, valueB], dim=0)

        # The result must be large enough to hold both operands.
        M = max(src.size(0), other.size(0))
        N = max(src.size(1), other.size(1))
        sparse_sizes = (M, N)

        out = SparseTensor(row=row, col=col, value=value,
                           sparse_sizes=sparse_sizes)
        # Merge duplicate (row, col) pairs, summing their values.
        out = out.coalesce(reduce='sum')
        return out

    else:
        raise NotImplementedError


def _implicit_ones(like: Tensor, num_entries: int) -> Tensor:
    """Return ones matching ``like`` in dtype, device and trailing dims,
    with ``num_entries`` leading entries.

    Stands in for the values of a pattern-only sparse operand, whose missing
    values are treated as ones.
    """
    size = list(like.size())
    size[0] = num_entries
    return like.new_ones(size)
rusty1s's avatar
rusty1s committed
60
61


rusty1s's avatar
rusty1s committed
62
63
def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
    """In-place addition of a dense tensor to the stored values of ``src``.

    ``other`` must have shape ``(src.size(0), 1, ...)`` (row-wise) or
    ``(1, src.size(1), ...)`` (column-wise). It is broadcast onto the stored
    non-zeros of ``src`` and added to their values in place; a pattern-only
    ``src`` (no stored values) is treated as having all-ones values.

    Args:
        src: sparse tensor whose values are updated via ``set_value_``.
        other: dense tensor to add.

    Returns:
        The result of ``src.set_value_`` (presumably ``src`` itself --
        confirm).

    Raises:
        ValueError: if ``other`` has fewer than two dimensions or an
            incompatible shape.
    """
    rowptr, col, value = src.csr()
    # Check the rank first so 0-D/1-D input raises the ValueError below
    # rather than an IndexError from ``other.size(1)``.
    is_2d = other.dim() >= 2
    if is_2d and other.size(0) == src.size(0) and other.size(1) == 1:
        # Row-wise: expand per-row entries to one per stored non-zero.
        other = gather_csr(other.squeeze(1), rowptr)
    elif is_2d and other.size(0) == 1 and other.size(1) == src.size(1):
        # Column-wise: pick the entry for each non-zero's column.
        other = other.squeeze(0)[col]
    else:
        raise ValueError(
            f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or '
            f'(1, {src.size(1)}, ...), but got size {other.size()}.')

    if value is not None:
        value = value.add_(other.to(value.dtype))
    else:
        # ``other`` is a freshly allocated tensor here (gather / advanced
        # indexing), so mutating it does not touch the caller's input.
        value = other.add_(1)
    return src.set_value_(value, layout='coo')
rusty1s's avatar
rusty1s committed
78

rusty1s's avatar
rusty1s committed
79

rusty1s's avatar
rusty1s committed
80
81
82
83
def add_nnz(src: SparseTensor, other: torch.Tensor,
            layout: Optional[str] = None) -> SparseTensor:
    """Add ``other`` element-wise to the stored non-zero values of ``src``.

    ``other`` is cast to the dtype of the stored values before adding. If
    ``src`` stores no values, the new values are ``other + 1``, as if the
    missing values were ones. Only out-of-place tensor ops are used, so
    ``other`` is never mutated.

    Args:
        src: sparse tensor supplying the sparsity pattern and values.
        other: dense tensor to add; presumably one leading entry per stored
            non-zero -- TODO confirm.
        layout: forwarded unchanged to ``SparseTensor.set_value``.

    Returns:
        Whatever ``src.set_value`` returns for the new values.
    """
    stored = src.storage.value()
    if stored is None:
        new_value = other.add(1)
    else:
        new_value = stored.add(other.to(stored.dtype))
    return src.set_value(new_value, layout=layout)
rusty1s's avatar
rusty1s committed
88
89


rusty1s's avatar
rusty1s committed
90
91
92
93
def add_nnz_(src: SparseTensor, other: torch.Tensor,
             layout: Optional[str] = None) -> SparseTensor:
    """In-place counterpart of ``add_nnz``.

    When ``src`` stores values, ``other`` (cast to their dtype) is added to
    them in place. When it does not, ``other + 1`` is computed out of place
    (``other`` itself is left untouched) and installed as the new values.
    Either way the result is written back through ``SparseTensor.set_value_``.

    Args:
        src: sparse tensor whose values are updated.
        other: dense tensor to add; presumably one leading entry per stored
            non-zero -- TODO confirm.
        layout: forwarded unchanged to ``SparseTensor.set_value_``.
    """
    stored = src.storage.value()
    if stored is None:
        new_value = other.add(1)
    else:
        new_value = stored.add_(other.to(stored.dtype))
    return src.set_value_(new_value, layout=layout)
rusty1s's avatar
rusty1s committed
98
99


rusty1s's avatar
rusty1s committed
100
101
102
103
104
105
# Expose the functional API of this module as SparseTensor methods and
# arithmetic operators.
SparseTensor.add = lambda self, other: add(self, other)
SparseTensor.add_ = lambda self, other: add_(self, other)
SparseTensor.add_nnz = lambda self, other, layout=None: add_nnz(
    self, other, layout)
SparseTensor.add_nnz_ = lambda self, other, layout=None: add_nnz_(
    self, other, layout)

SparseTensor.__add__ = SparseTensor.add
SparseTensor.__radd__ = SparseTensor.add
# ``add_`` only accepts dense operands, so ``a += b`` with a non-Tensor
# ``b`` (e.g. another SparseTensor) falls back to the out-of-place ``add``;
# Python then rebinds ``a`` to the returned tensor.
SparseTensor.__iadd__ = lambda self, other: (
    add_(self, other) if isinstance(other, torch.Tensor) else add(self, other))