add.py 4.96 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
rusty1s's avatar
rusty1s committed
2
from torch_scatter import gather_csr
rusty1s's avatar
rusty1s committed
3
4


rusty1s's avatar
rusty1s committed
5
6
7
8
def is_scalar(other):
    """Return ``True`` if ``other`` is a plain Python ``int`` or ``float``.

    Subclasses count as well, so ``bool`` values are reported as scalars.
    """
    return isinstance(other, (int, float))


rusty1s's avatar
rusty1s committed
9
10
11
12
def sparse_add(matA, matB):
    """Add two sparse matrices and return a new matrix of ``matA``'s class.

    The merged sparsity pattern is computed with SciPy. Every non-zero of
    ``matA`` is tagged with 1 and every non-zero of ``matB`` with 2, so after
    the SciPy addition each entry of the result holds 1 (present only in
    ``matA``), 2 (present only in ``matB``) or 3 (present in both). These
    tags are then turned into masks that scatter the real values of both
    operands into the output value tensor. An operand without values
    contributes an implicit 1 per non-zero. If neither operand has values,
    the result has no values either.

    Args:
        matA: First sparse operand. Its class, storage class, sparse size,
            dtype and device are used for the result.
        matB: Second sparse operand.

    Returns:
        A new sparse matrix built via ``matA.__class__.from_storage``.

    Raises:
        NotImplementedError: If ``matA`` lives on a CUDA device; only the
            CPU (SciPy) path is implemented.
    """
    if matA.is_cuda:
        # The CUDA branch used to be an empty placeholder, which made the
        # function crash further down on unbound locals. Fail explicitly.
        raise NotImplementedError(
            'sparse_add is not implemented for CUDA tensors.')

    nnzA, nnzB = matA.nnz(), matB.nnz()
    # Tag values: 1 marks entries of A, 2 marks entries of B. Their sum (3)
    # marks entries present in both; uint8 is sufficient for that range.
    valA = torch.full((nnzA, ), 1, dtype=torch.uint8, device=matA.device)
    valB = torch.full((nnzB, ), 2, dtype=torch.uint8, device=matB.device)

    matA_ = matA.set_value(valA, layout='csr').to_scipy(layout='csr')
    matB_ = matB.set_value(valB, layout='csr').to_scipy(layout='csr')
    matC_ = matA_ + matB_
    # Grab the CSR row pointer before converting to COO for row/col indices.
    rowptr = torch.from_numpy(matC_.indptr).to(torch.long)
    matC_ = matC_.tocoo()
    row = torch.from_numpy(matC_.row).to(torch.long)
    col = torch.from_numpy(matC_.col).to(torch.long)
    index = torch.stack([row, col], dim=0)
    valC_ = torch.from_numpy(matC_.data)

    value = None
    if matA.has_value() or matB.has_value():
        # Tag 1 or 3 -> entry exists in A; tag 2 or 3 -> entry exists in B.
        maskA, maskB = valC_ != 2, valC_ >= 2

        # Take the trailing (non-sparse) dimensions from whichever operand
        # has more dimensions. The original picked `matA` in both branches.
        size = matA.size() if matA.dim() >= matB.dim() else matB.size()
        size = (valC_.size(0), ) + tuple(size[2:])

        value = torch.zeros(size, dtype=matA.dtype, device=matA.device)
        # NOTE(review): this relies on the merged entries being in the same
        # order as each operand's stored values - confirm for unsorted input.
        value[maskA] += matA.storage.value if matA.has_value() else 1
        value[maskB] += matB.storage.value if matB.has_value() else 1

    storage = matA.storage.__class__(index, value, matA.sparse_size(),
                                     rowptr=rowptr, is_sorted=True)

    return matA.__class__.from_storage(storage)
rusty1s's avatar
rusty1s committed
42
43
44


def add(src, other):
    """Add ``other`` to the values of sparse tensor ``src`` (out of place).

    Dispatches on the type of ``other``:

    * ``int`` / ``float``: delegated to ``add_nnz``.
    * tensor of size ``(src.size(0), 1, ...)``: row-wise addition; the entry
      for each row is expanded to that row's non-zeros via ``gather_csr``
      and the CSR row pointer.
    * tensor of size ``(1, src.size(1), ...)``: column-wise addition; the
      entry for each column is looked up with the column index of every
      non-zero.

    If ``src`` has no values, every non-zero counts as 1.

    Args:
        src: The sparse tensor to add to.
        other: Scalar or dense tensor to add.

    Returns:
        The result of ``src.set_value(...)`` with the summed values (or the
        result of ``add_nnz`` for scalars).

    Raises:
        ValueError: If ``other`` is a tensor whose size matches neither the
            row-wise nor the column-wise pattern (including tensors with
            fewer than two dimensions), or if ``other`` has an unsupported
            type.
        NotImplementedError: If ``other`` is of the same class as ``src``.
    """
    if is_scalar(other):
        return add_nnz(src, other)

    if torch.is_tensor(other):
        # Tensors with fewer than two dimensions cannot match either
        # pattern. Skipping the size checks for them avoids an IndexError
        # from `other.size(1)` and reports the size mismatch instead.
        if other.dim() >= 2:
            rowptr, col, _ = src.csr()

            if other.size(0) == src.size(0) and other.size(1) == 1:
                # Row-wise: one entry per row, expanded to the non-zeros.
                other = gather_csr(other.squeeze(1), rowptr)
                # `other` now names the gathered tensor, so the in-place add
                # is not applied to the tensor the caller passed in.
                value = other.add_(
                    src.storage.value if src.has_value() else 1)
                return src.set_value(value, layout='csr')

            if other.size(0) == 1 and other.size(1) == src.size(1):
                # Column-wise: pick the entry of each non-zero's column.
                other = other.squeeze(0)[col]
                value = other.add_(
                    src.storage.value if src.has_value() else 1)
                return src.set_value(value, layout='coo')

        raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,'
                         f' ...) or (1, {src.size(1)}, ...), but got size '
                         f'{other.size()}.')

    if isinstance(other, src.__class__):
        raise NotImplementedError

    raise ValueError('Argument `other` needs to be of type `int`, `float`, '
                     '`torch.tensor` or `torch_sparse.SparseTensor`.')


rusty1s's avatar
rusty1s committed
71
def add_(src, other):
    """Add ``other`` to the values of sparse tensor ``src`` in place.

    Dispatches on the type of ``other``:

    * ``int`` / ``float``: delegated to ``add_nnz_``.
    * tensor of size ``(src.size(0), 1, ...)``: row-wise addition; the entry
      for each row is expanded to that row's non-zeros via ``gather_csr``
      and the CSR row pointer.
    * tensor of size ``(1, src.size(1), ...)``: column-wise addition; the
      entry for each column is looked up with the column index of every
      non-zero.

    When ``src`` has values they are updated with ``Tensor.add_``. When it
    has none, every non-zero counts as 1 and the expanded ``other`` plus 1
    becomes the new value tensor.

    Args:
        src: The sparse tensor to modify.
        other: Scalar or dense tensor to add.

    Returns:
        The result of ``src.set_value_(...)`` (or the result of ``add_nnz_``
        for scalars).

    Raises:
        ValueError: If ``other`` is a tensor whose size matches neither the
            row-wise nor the column-wise pattern (including tensors with
            fewer than two dimensions), or if ``other`` has an unsupported
            type.
        NotImplementedError: If ``other`` is of the same class as ``src``.
    """
    if is_scalar(other):
        return add_nnz_(src, other)

    if torch.is_tensor(other):
        # Tensors with fewer than two dimensions cannot match either
        # pattern. Skipping the size checks for them avoids an IndexError
        # from `other.size(1)` and reports the size mismatch instead.
        if other.dim() >= 2:
            rowptr, col, _ = src.csr()

            if other.size(0) == src.size(0) and other.size(1) == 1:
                # Row-wise: one entry per row, expanded to the non-zeros.
                other = gather_csr(other.squeeze(1), rowptr)
                if src.has_value():
                    value = src.storage.value.add_(other)
                else:
                    value = other.add_(1)
                return src.set_value_(value, layout='csr')

            if other.size(0) == 1 and other.size(1) == src.size(1):
                # Column-wise: pick the entry of each non-zero's column.
                other = other.squeeze(0)[col]
                if src.has_value():
                    value = src.storage.value.add_(other)
                else:
                    value = other.add_(1)
                return src.set_value_(value, layout='coo')

        raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,'
                         f' ...) or (1, {src.size(1)}, ...), but got size '
                         f'{other.size()}.')

    if isinstance(other, src.__class__):
        raise NotImplementedError

    raise ValueError('Argument `other` needs to be of type `int`, `float`, '
                     '`torch.tensor` or `torch_sparse.SparseTensor`.')
rusty1s's avatar
rusty1s committed
102
103


rusty1s's avatar
rusty1s committed
104
def add_nnz(src, other, layout=None):
    """Add ``other`` to every stored value of ``src`` (out of place).

    ``other`` may be a Python scalar or a tensor. If ``src`` has no values,
    each non-zero counts as 1, so the new values are ``1 + other``.

    Args:
        src: The sparse tensor whose values are the left operand.
        other: ``int``, ``float`` or tensor to add.
        layout: Accepted but not used by this function; the new values are
            always set with ``layout='coo'``.

    Returns:
        The result of ``src.set_value(...)`` with the summed values.

    Raises:
        ValueError: If ``other`` is neither a scalar nor a tensor.
    """
    if not is_scalar(other) and not torch.is_tensor(other):
        raise ValueError('Argument `other` needs to be of type `int`, `float` or '
                         '`torch.tensor`.')

    if src.has_value():
        # Scalars and tensors are added to existing values in the same way.
        summed = src.storage.value + other
    elif is_scalar(other):
        # No stored values: materialise the implicit ones plus the scalar.
        summed = torch.full((src.nnz(), ), 1 + other, device=src.device)
    else:
        summed = other + 1

    return src.set_value(summed, layout='coo')
rusty1s's avatar
rusty1s committed
121
122
123


def add_nnz_(src, other, layout=None):
    """Add ``other`` to every stored value of ``src`` in place.

    ``other`` may be a Python scalar or a tensor. Existing values are
    updated with ``Tensor.add_``. If ``src`` has no values, each non-zero
    counts as 1 and a new value tensor holding ``1 + other`` is installed.

    Args:
        src: The sparse tensor to modify.
        other: ``int``, ``float`` or tensor to add.
        layout: Accepted but not used by this function; the new values are
            always set with ``layout='coo'``.

    Returns:
        The result of ``src.set_value_(...)``.

    Raises:
        ValueError: If ``other`` is neither a scalar nor a tensor.
    """
    scalar = is_scalar(other)
    if not (scalar or torch.is_tensor(other)):
        raise ValueError('Argument `other` needs to be of type `int`, `float` or '
                         '`torch.tensor`.')

    if src.has_value():
        updated = src.storage.value.add_(other)
    elif scalar:
        updated = torch.full((src.nnz(), ), 1 + other, device=src.device)
    else:
        # Out of place on purpose: there is no value tensor to update.
        updated = other + 1

    return src.set_value_(updated, layout='coo')