from typing import Optional

import torch
from torch_scatter import gather_csr
from torch_sparse.tensor import SparseTensor


@torch.jit.script
def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
    """Broadcast-add a dense row or column vector to the non-zeros of ``src``.

    Only the stored (non-zero) entries of ``src`` are updated; the sparsity
    pattern is unchanged.  ``other`` must have one of two shapes:

    * ``(src.size(0), 1, ...)`` -- one value per row, added to every stored
      entry of that row;
    * ``(1, src.size(1), ...)`` -- one value per column, added to every stored
      entry of that column.

    If ``src`` stores no values, the expanded ``other`` plus one becomes the
    new value tensor.

    Args:
        src: Sparse matrix whose stored values are added to.
        other: Dense tensor in one of the two shapes above.

    Returns:
        The ``SparseTensor`` produced by ``src.set_value`` with the sums.

    Raises:
        ValueError: If ``other`` has fewer than two dimensions or its leading
            two sizes match neither shape above.
    """
    rowptr, col, value = src.csr()
    # Check the rank first so 0-D/1-D inputs reach the size-mismatch
    # ValueError below instead of failing inside ``other.size(1)``.
    is_2d = other.dim() >= 2
    if is_2d and other.size(0) == src.size(0) and other.size(1) == 1:
        # Row-wise: expand the per-row values to one value per stored entry
        # (assumed semantics of torch_scatter.gather_csr over ``rowptr``).
        other = gather_csr(other.squeeze(1), rowptr)
    elif is_2d and other.size(0) == 1 and other.size(1) == src.size(1):
        # Column-wise: pick the value belonging to each stored entry's column.
        other = other.squeeze(0)[col]
    else:
        raise ValueError(
            f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or '
            f'(1, {src.size(1)}, ...), but got size {other.size()}.')
    # ``other`` now comes from indexing (a copy) or, presumably, a freshly
    # allocated gather result, so the in-place adds below should not touch
    # the caller's tensor.
    if value is not None:
        value = other.to(value.dtype).add_(value)
    else:
        value = other.add_(1)
    return src.set_value(value, layout='coo')


@torch.jit.script
def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
    """Broadcast-add a dense row or column vector to ``src`` in place.

    Only the stored (non-zero) entries of ``src`` are updated.  ``other``
    must have one of two shapes:

    * ``(src.size(0), 1, ...)`` -- one value per row, added to every stored
      entry of that row;
    * ``(1, src.size(1), ...)`` -- one value per column, added to every stored
      entry of that column.

    When ``src`` already stores values they are modified in place with
    ``Tensor.add_``; otherwise the expanded ``other`` plus one becomes the
    new value tensor.  In both cases the result is written back through
    ``src.set_value_``.

    Args:
        src: Sparse matrix whose stored values are updated.
        other: Dense tensor in one of the two shapes above.

    Returns:
        Whatever ``src.set_value_`` returns.

    Raises:
        ValueError: If ``other`` has fewer than two dimensions or its leading
            two sizes match neither shape above.
    """
    rowptr, col, value = src.csr()
    # Check the rank first so 0-D/1-D inputs reach the size-mismatch
    # ValueError below instead of failing inside ``other.size(1)``.
    is_2d = other.dim() >= 2
    if is_2d and other.size(0) == src.size(0) and other.size(1) == 1:
        # Row-wise: expand the per-row values to one value per stored entry
        # (assumed semantics of torch_scatter.gather_csr over ``rowptr``).
        other = gather_csr(other.squeeze(1), rowptr)
    elif is_2d and other.size(0) == 1 and other.size(1) == src.size(1):
        # Column-wise: pick the value belonging to each stored entry's column.
        other = other.squeeze(0)[col]
    else:
        raise ValueError(
            f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or '
            f'(1, {src.size(1)}, ...), but got size {other.size()}.')
    if value is not None:
        # Mutates the value tensor returned by ``src.csr()`` -- presumably the
        # storage's own buffer; TODO confirm.
        value = value.add_(other.to(value.dtype))
    else:
        value = other.add_(1)
    return src.set_value_(value, layout='coo')

@torch.jit.script
def add_nnz(src: SparseTensor, other: torch.Tensor,
            layout: Optional[str] = None) -> SparseTensor:
    """Add ``other`` directly to the stored values of ``src`` (out of place).

    ``other`` is combined with the value tensor as-is (no row/column
    expansion), so it presumably lines up with or broadcasts against the
    stored non-zeros -- TODO confirm with callers.  When ``src`` stores no
    values, ``other + 1`` becomes the new value tensor.

    Args:
        src: Sparse matrix providing the stored values.
        other: Dense tensor added to the stored values, cast to their dtype.
        layout: Forwarded unchanged to ``src.set_value``.

    Returns:
        The ``SparseTensor`` produced by ``src.set_value``.
    """
    stored = src.storage.value()
    if stored is None:
        summed = other.add(1)
    else:
        summed = stored.add(other.to(stored.dtype))
    return src.set_value(summed, layout=layout)


@torch.jit.script
def add_nnz_(src: SparseTensor, other: torch.Tensor,
             layout: Optional[str] = None) -> SparseTensor:
    """Add ``other`` directly to the stored values of ``src`` in place.

    ``other`` is combined with the value tensor as-is (no row/column
    expansion), so it presumably lines up with or broadcasts against the
    stored non-zeros -- TODO confirm with callers.  Existing values are
    mutated with ``Tensor.add_``; when ``src`` stores no values, a new tensor
    ``other + 1`` is used instead (``other`` itself is left untouched).

    Args:
        src: Sparse matrix whose stored values are updated.
        other: Dense tensor added to the stored values, cast to their dtype.
        layout: Forwarded unchanged to ``src.set_value_``.

    Returns:
        Whatever ``src.set_value_`` returns.
    """
    stored = src.storage.value()
    if stored is None:
        summed = other.add(1)
    else:
        summed = stored.add_(other.to(stored.dtype))
    return src.set_value_(summed, layout=layout)


# Expose the scripted functions above as SparseTensor methods.  Each lambda
# forwards its arguments positionally to the module-level function.
SparseTensor.add = lambda self, other: add(self, other)
SparseTensor.add_ = lambda self, other: add_(self, other)
SparseTensor.add_nnz = (
    lambda self, other, layout=None: add_nnz(self, other, layout))
SparseTensor.add_nnz_ = (
    lambda self, other, layout=None: add_nnz_(self, other, layout))

# Operator overloads: ``+`` and reflected ``+`` use the out-of-place method,
# ``+=`` uses the in-place one.
SparseTensor.__add__ = SparseTensor.__radd__ = SparseTensor.add
SparseTensor.__iadd__ = SparseTensor.add_