# add.py
import torch
from torch_scatter import gather_csr


def sparse_add(matA, matB):
    r"""Add two sparse matrices over the union of their sparsity patterns.

    Every stored entry of :obj:`matA` is tagged with ``1`` and every stored
    entry of :obj:`matB` with ``2``. Both are converted to SciPy CSR matrices
    and summed, so each entry of the result carries ``1`` (present only in
    A), ``2`` (present only in B) or ``3`` (present in both). These tags are
    then used to scatter the operands' values into the output.

    Args:
        matA: First sparse operand; the result has the same class.
        matB: Second sparse operand.

    Returns:
        A new sparse matrix of ``matA``'s class. Its value is ``None`` when
        neither operand has values; otherwise an operand without values
        contributes ``1`` at each of its entries.

    Raises:
        NotImplementedError: If :obj:`matA` is on a CUDA device (only the
            SciPy-based CPU path exists).
    """
    if matA.is_cuda:
        # Fail explicitly: without a CUDA implementation the variables used
        # below (`index`, `rowptr`, `valC_`) would never be assigned.
        raise NotImplementedError(
            '`sparse_add` is not implemented for CUDA tensors yet.')

    nnzA, nnzB = matA.nnz(), matB.nnz()
    # Origin tags; the largest possible sum is 1 + 2 = 3, so uint8 suffices.
    valA = torch.full((nnzA, ), 1, dtype=torch.uint8, device=matA.device)
    valB = torch.full((nnzB, ), 2, dtype=torch.uint8, device=matB.device)

    matA_ = matA.set_value(valA, layout='csr').to_scipy(layout='csr')
    matB_ = matB.set_value(valB, layout='csr').to_scipy(layout='csr')
    matC_ = matA_ + matB_

    rowptr = torch.from_numpy(matC_.indptr).to(torch.long)
    matC_ = matC_.tocoo()
    row = torch.from_numpy(matC_.row).to(torch.long)
    col = torch.from_numpy(matC_.col).to(torch.long)
    index = torch.stack([row, col], dim=0)
    valC_ = torch.from_numpy(matC_.data)

    value = None
    if matA.has_value() or matB.has_value():
        # Tag 1 or 3 -> entry exists in A; tag 2 or 3 -> entry exists in B.
        # NOTE(review): scattering `storage.value` through these masks
        # assumes each operand's values are stored in the same row-major
        # order as the SciPy result — confirm operands are coalesced.
        maskA, maskB = valC_ != 2, valC_ >= 2

        # Take the trailing value dimensions from whichever operand has more
        # dimensions (the previous code always used `matA.size()`).
        size = matA.size() if matA.dim() >= matB.dim() else matB.size()
        size = (valC_.size(0), ) + size[2:]

        value = torch.zeros(size, dtype=matA.dtype, device=matA.device)
        value[maskA] += matA.storage.value if matA.has_value() else 1
        value[maskB] += matB.storage.value if matB.has_value() else 1

    storage = matA.storage.__class__(index, value, matA.sparse_size(),
                                     rowptr=rowptr, is_sorted=True)

    return matA.__class__.from_storage(storage)


def add(src, other):
    r"""Out-of-place addition of :obj:`other` to the stored entries of :obj:`src`.

    Args:
        src: Sparse matrix whose stored (non-zero) entries are updated.
        other: One of

            * ``int`` / ``float``: delegated to :func:`add_nnz`;
            * a tensor of size ``(src.size(0), 1, ...)``: row-wise
              broadcast, i.e. entry ``i`` is added to every stored entry of
              row ``i``;
            * a tensor of size ``(1, src.size(1), ...)``: column-wise
              broadcast, i.e. entry ``j`` is added to every stored entry of
              column ``j``.

    Returns:
        A new sparse matrix with the updated values. Entries of :obj:`src`
        without an explicit value are treated as ``1``.

    Raises:
        ValueError: If a tensor :obj:`other` matches neither broadcast shape
            (including tensors with fewer than two dimensions), or if
            :obj:`other` has an unsupported type.
        NotImplementedError: If :obj:`other` is a sparse matrix of the same
            class as :obj:`src`.
    """
    if isinstance(other, (int, float)):
        return add_nnz(src, other)

    elif torch.is_tensor(other):
        rowptr, col, _ = src.csr()

        # Guard `other.size(1)`: on 0-D/1-D tensors it would raise an opaque
        # IndexError instead of the size-mismatch error below.
        is_2d = other.dim() >= 2

        if is_2d and other.size(0) == src.size(0) and other.size(1) == 1:
            # Row-wise: expand each row's entry over that row's CSR segment.
            other = gather_csr(other.squeeze(1), rowptr)
            # Out-of-place `+` (as in `add_nnz`) so standard type promotion
            # applies, e.g. an integer `other` with floating-point values.
            value = other + (src.storage.value if src.has_value() else 1)
            return src.set_value(value, layout='csr')

        if is_2d and other.size(0) == 1 and other.size(1) == src.size(1):
            # Col-wise: pick each stored entry's column value.
            other = other.squeeze(0)[col]
            value = other + (src.storage.value if src.has_value() else 1)
            return src.set_value(value, layout='coo')

        raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,'
                         f' ...) or (1, {src.size(1)}, ...), but got size '
                         f'{other.size()}.')

    elif isinstance(other, src.__class__):
        # Sparse + sparse is not supported by this entry point yet.
        raise NotImplementedError

    raise ValueError('Argument `other` needs to be of type `int`, `float`, '
                     '`torch.tensor` or `torch_sparse.SparseTensor`.')


def add_(src, other):
    r"""In-place addition of :obj:`other` to the stored entries of :obj:`src`.

    Accepts the same kinds of :obj:`other` as :func:`add`:

    * ``int`` / ``float``: delegated to :func:`add_nnz_`;
    * a tensor of size ``(src.size(0), 1, ...)``: row-wise broadcast;
    * a tensor of size ``(1, src.size(1), ...)``: column-wise broadcast.

    If :obj:`src` has values, they are updated in place via ``add_``.
    Otherwise the broadcast :obj:`other` plus ``1`` becomes the new value.

    Returns:
        The result of ``src.set_value_(...)``.

    Raises:
        ValueError: If a tensor :obj:`other` matches neither broadcast shape
            (including tensors with fewer than two dimensions), or if
            :obj:`other` has an unsupported type.
        NotImplementedError: If :obj:`other` is a sparse matrix of the same
            class as :obj:`src`.
    """
    if isinstance(other, (int, float)):
        return add_nnz_(src, other)

    elif torch.is_tensor(other):
        rowptr, col, _ = src.csr()

        # Guard `other.size(1)`: on 0-D/1-D tensors it would raise an opaque
        # IndexError instead of the size-mismatch error below.
        is_2d = other.dim() >= 2

        if is_2d and other.size(0) == src.size(0) and other.size(1) == 1:
            # Row-wise: expand each row's entry over that row's CSR segment.
            other = gather_csr(other.squeeze(1), rowptr)
            if src.has_value():
                value = src.storage.value.add_(other)
            else:
                # Presumably a freshly gathered tensor, so the caller's input
                # is not modified — confirm gather_csr allocates its output.
                value = other.add_(1)
            return src.set_value_(value, layout='csr')

        if is_2d and other.size(0) == 1 and other.size(1) == src.size(1):
            # Col-wise; advanced indexing returns a copy of `other`.
            other = other.squeeze(0)[col]
            if src.has_value():
                value = src.storage.value.add_(other)
            else:
                value = other.add_(1)
            return src.set_value_(value, layout='coo')

        raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,'
                         f' ...) or (1, {src.size(1)}, ...), but got size '
                         f'{other.size()}.')

    elif isinstance(other, src.__class__):
        # Sparse + sparse is not supported by this entry point yet.
        raise NotImplementedError

    raise ValueError('Argument `other` needs to be of type `int`, `float`, '
                     '`torch.tensor` or `torch_sparse.SparseTensor`.')


def add_nnz(src, other, layout=None):
    r"""Return a copy of :obj:`src` with :obj:`other` added to every stored value.

    Args:
        src: Sparse matrix whose stored values are offset.
        other (int, float or Tensor): Amount to add. When :obj:`src` has no
            explicit values, they are treated as ``1``.
        layout: Not used by this function; the result is always written
            with ``layout='coo'``.

    Raises:
        ValueError: If :obj:`other` is not an ``int``, ``float`` or tensor.
    """
    is_scalar = isinstance(other, (int, float))
    if not (is_scalar or torch.is_tensor(other)):
        raise ValueError('Argument `other` needs to be of type '
                         '`int`, `float` or `torch.tensor`.')

    if src.has_value():
        new_value = src.storage.value + other
    elif is_scalar:
        # Implicit values are all 1, so every entry becomes `1 + other`.
        new_value = torch.full((src.nnz(), ), 1 + other, device=src.device)
    else:
        new_value = other + 1

    return src.set_value(new_value, layout='coo')


def add_nnz_(src, other, layout=None):
    r"""In-place variant of :func:`add_nnz`.

    Adds :obj:`other` to every stored value of :obj:`src`, mutating the
    existing value tensor when there is one.

    Args:
        src: Sparse matrix whose stored values are offset.
        other (int, float or Tensor): Amount to add. When :obj:`src` has no
            explicit values, they are treated as ``1``.
        layout: Not used by this function; the result is always written
            with ``layout='coo'``.

    Raises:
        ValueError: If :obj:`other` is not an ``int``, ``float`` or tensor.
    """
    scalar_other = isinstance(other, (int, float))
    if not scalar_other and not torch.is_tensor(other):
        raise ValueError('Argument `other` needs to be of type '
                         '`int`, `float` or `torch.tensor`.')

    if src.has_value():
        # Mutate the stored value tensor directly.
        updated = src.storage.value.add_(other)
    elif scalar_other:
        updated = torch.full((src.nnz(), ), 1 + other, device=src.device)
    else:
        # There is no existing value tensor to update, so allocate a new one.
        updated = other + 1

    return src.set_value_(updated, layout='coo')