add.py 4.92 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
rusty1s's avatar
rusty1s committed
2
from torch_scatter import gather_csr
rusty1s's avatar
rusty1s committed
3
from torch_sparse.utils import is_scalar
rusty1s's avatar
rusty1s committed
4
5


rusty1s's avatar
rusty1s committed
6
7
8
9
def sparse_add(matA, matB):
    """Add two sparse matrices and return the sum as a new sparse matrix.

    The sparsity pattern of the sum is obtained by tagging every stored
    entry of ``matA`` with ``1`` and every stored entry of ``matB`` with
    ``2`` and adding the tagged matrices via SciPy. In the result, a tag of
    ``1`` marks an entry only stored in ``matA``, ``2`` an entry only stored
    in ``matB`` and ``3`` an entry stored in both.

    Args:
        matA: Left operand. Its class and its storage class are used to
            build the result.
        matB: Right operand.

    Returns:
        An object created by ``matA.__class__.from_storage``. It carries
        values only if at least one operand has values; an operand without
        values contributes ``1`` for each of its stored entries.

    Raises:
        NotImplementedError: If ``matA`` is a CUDA tensor (only the
            SciPy-based CPU path is implemented).
    """
    if matA.is_cuda:
        # Previously a bare `pass`, which left `rowptr`, `index` and `valC_`
        # unbound and crashed further down with an UnboundLocalError.
        raise NotImplementedError(
            '`sparse_add` is not yet implemented for CUDA tensors.')

    nnzA, nnzB = matA.nnz(), matB.nnz()
    # Tag values: 1 -> entry of A, 2 -> entry of B (so 3 -> entry of both).
    valA = torch.full((nnzA, ), 1, dtype=torch.uint8, device=matA.device)
    valB = torch.full((nnzB, ), 2, dtype=torch.uint8, device=matB.device)

    matA_ = matA.set_value(valA, layout='csr').to_scipy(layout='csr')
    matB_ = matB.set_value(valB, layout='csr').to_scipy(layout='csr')
    matC_ = matA_ + matB_
    rowptr = torch.from_numpy(matC_.indptr).to(torch.long)
    matC_ = matC_.tocoo()
    row = torch.from_numpy(matC_.row).to(torch.long)
    col = torch.from_numpy(matC_.col).to(torch.long)
    index = torch.stack([row, col], dim=0)
    valC_ = torch.from_numpy(matC_.data)

    value = None
    if matA.has_value() or matB.has_value():
        # Tags 1 and 3 belong to A, tags 2 and 3 belong to B.
        maskA, maskB = valC_ != 2, valC_ >= 2

        # Take the trailing (per-entry) value dimensions from the operand
        # with more dimensions. The fallback used to be `matA.size()` too,
        # which ignored `matB` entirely.
        size = matA.size() if matA.dim() >= matB.dim() else matB.size()
        size = (valC_.size(0), ) + size[2:]

        value = torch.zeros(size, dtype=matA.dtype, device=matA.device)
        value[maskA] += matA.storage.value if matA.has_value() else 1
        value[maskB] += matB.storage.value if matB.has_value() else 1

    storage = matA.storage.__class__(index, value, matA.sparse_size(),
                                     rowptr=rowptr, is_sorted=True)

    return matA.__class__.from_storage(storage)
rusty1s's avatar
rusty1s committed
39
40
41


def add(src, other):
    """Add ``other`` to the stored values of the sparse matrix ``src``.

    Stored entries of ``src`` without values are treated as ``1``.

    Args:
        src: Sparse matrix operand.
        other: A scalar (delegated to ``add_nnz``), or a dense tensor of
            size ``(src.size(0), 1, ...)`` (one summand per row) or
            ``(1, src.size(1), ...)`` (one summand per column).

    Returns:
        The result of ``src.set_value`` with the summed values (or of
        ``add_nnz`` for scalars).

    Raises:
        ValueError: If ``other`` is a tensor whose size matches neither
            supported layout (including tensors with fewer than two
            dimensions), or if ``other`` has an unsupported type.
        NotImplementedError: If ``other`` is an instance of ``src``'s class.
    """
    if is_scalar(other):
        return add_nnz(src, other)

    elif torch.is_tensor(other):
        rowptr, col, _ = src.csr()
        # Check the rank first so 0-/1-dim tensors reach the size-mismatch
        # error below instead of an IndexError from `other.size(1)`.
        has_two_dims = other.dim() >= 2

        if (has_two_dims and other.size(0) == src.size(0)
                and other.size(1) == 1):  # Row-wise...
            # Expand the per-row summands to one per stored entry.
            other = gather_csr(other.squeeze(1), rowptr)
            value = other.add_(src.storage.value if src.has_value() else 1)
            return src.set_value(value, layout='csr')

        if (has_two_dims and other.size(0) == 1
                and other.size(1) == src.size(1)):  # Col-wise...
            # Indexing with `col` yields a fresh tensor, so the in-place
            # add below does not modify the caller's `other`.
            other = other.squeeze(0)[col]
            value = other.add_(src.storage.value if src.has_value() else 1)
            return src.set_value(value, layout='coo')

        raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,'
                         f' ...) or (1, {src.size(1)}, ...), but got size '
                         f'{other.size()}.')

    elif isinstance(other, src.__class__):
        raise NotImplementedError

    raise ValueError('Argument `other` needs to be of type `int`, `float`, '
                     '`torch.tensor` or `torch_sparse.SparseTensor`.')


rusty1s's avatar
rusty1s committed
68
def add_(src, other):
    """In-place variant of ``add``: add ``other`` to the values of ``src``.

    If ``src`` has values, they are updated in place via ``Tensor.add_``;
    otherwise each stored entry is treated as ``1``. The result is written
    back with ``src.set_value_``.

    Args:
        src: Sparse matrix operand that is updated.
        other: A scalar (delegated to ``add_nnz_``), or a dense tensor of
            size ``(src.size(0), 1, ...)`` (one summand per row) or
            ``(1, src.size(1), ...)`` (one summand per column).

    Returns:
        The result of ``src.set_value_`` (or of ``add_nnz_`` for scalars).

    Raises:
        ValueError: If ``other`` is a tensor whose size matches neither
            supported layout (including tensors with fewer than two
            dimensions), or if ``other`` has an unsupported type.
        NotImplementedError: If ``other`` is an instance of ``src``'s class.
    """
    if is_scalar(other):
        return add_nnz_(src, other)

    elif torch.is_tensor(other):
        rowptr, col, _ = src.csr()
        # Check the rank first so 0-/1-dim tensors reach the size-mismatch
        # error below instead of an IndexError from `other.size(1)`.
        has_two_dims = other.dim() >= 2

        if (has_two_dims and other.size(0) == src.size(0)
                and other.size(1) == 1):  # Row-wise...
            # Expand the per-row summands to one per stored entry.
            other = gather_csr(other.squeeze(1), rowptr)
            if src.has_value():
                value = src.storage.value.add_(other)
            else:
                value = other.add_(1)
            return src.set_value_(value, layout='csr')

        if (has_two_dims and other.size(0) == 1
                and other.size(1) == src.size(1)):  # Col-wise...
            # Indexing with `col` yields a fresh tensor, so `add_(1)` below
            # does not modify the caller's `other`.
            other = other.squeeze(0)[col]
            if src.has_value():
                value = src.storage.value.add_(other)
            else:
                value = other.add_(1)
            return src.set_value_(value, layout='coo')

        raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,'
                         f' ...) or (1, {src.size(1)}, ...), but got size '
                         f'{other.size()}.')

    elif isinstance(other, src.__class__):
        raise NotImplementedError

    raise ValueError('Argument `other` needs to be of type `int`, `float`, '
                     '`torch.tensor` or `torch_sparse.SparseTensor`.')
rusty1s's avatar
rusty1s committed
99
100


rusty1s's avatar
rusty1s committed
101
def add_nnz(src, other, layout=None):
    """Add ``other`` to every stored entry of ``src`` (out-of-place).

    Stored entries of ``src`` without values are treated as ``1``.

    Args:
        src: Sparse matrix operand.
        other: A scalar, or a tensor that is added to ``src.storage.value``.
        layout: Accepted but not used by this function; the new values are
            always passed to ``src.set_value`` with ``layout='coo'``.

    Returns:
        The result of ``src.set_value`` with the summed values.

    Raises:
        ValueError: If ``other`` is neither a scalar nor a tensor.
    """
    scalar = is_scalar(other)
    if not (scalar or torch.is_tensor(other)):
        raise ValueError('Argument `other` needs to be of type `int`, '
                         '`float` or `torch.tensor`.')

    if src.has_value():
        summed = src.storage.value + other
    elif scalar:
        # No stored values: every entry counts as 1.
        summed = torch.full((src.nnz(), ), 1 + other, device=src.device)
    else:
        summed = other + 1

    return src.set_value(summed, layout='coo')
rusty1s's avatar
rusty1s committed
118
119
120


def add_nnz_(src, other, layout=None):
    """Add ``other`` to every stored entry of ``src`` (in-place variant).

    If ``src`` has values, they are updated in place via ``Tensor.add_``;
    otherwise each stored entry is treated as ``1`` and a new value tensor
    is created.

    Args:
        src: Sparse matrix operand that is updated.
        other: A scalar, or a tensor that is added to ``src.storage.value``.
        layout: Accepted but not used by this function; the new values are
            always passed to ``src.set_value_`` with ``layout='coo'``.

    Returns:
        The result of ``src.set_value_`` with the summed values.

    Raises:
        ValueError: If ``other`` is neither a scalar nor a tensor.
    """
    scalar = is_scalar(other)
    if not (scalar or torch.is_tensor(other)):
        raise ValueError('Argument `other` needs to be of type `int`, '
                         '`float` or `torch.tensor`.')

    if src.has_value():
        updated = src.storage.value.add_(other)
    elif scalar:
        # No stored values: every entry counts as 1.
        updated = torch.full((src.nnz(), ), 1 + other, device=src.device)
    else:
        updated = other + 1  # No inplace operation possible.

    return src.set_value_(updated, layout='coo')