mul.py 2.6 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
from typing import Optional

rusty1s's avatar
rusty1s committed
3
4
import torch
from torch_scatter import gather_csr
rusty1s's avatar
rusty1s committed
5
6
7
8
9
10
11
from torch_sparse.tensor import SparseTensor


@torch.jit.script
def mul(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
    """Multiply the non-zero entries of ``src`` by a row or column vector.

    ``other`` is expanded onto the non-zero entries in one of two ways:

    * row-wise, when it has shape ``(src.size(0), 1, ...)``: each non-zero
      in row ``i`` is combined with ``other[i, 0]``;
    * column-wise, when it has shape ``(1, src.size(1), ...)``: each
      non-zero in column ``j`` is combined with ``other[0, j]``.

    Trailing dimensions of ``other`` are kept. If ``src`` stores no values,
    the expanded ``other`` becomes the value tensor. The result is stored
    with the out-of-place ``src.set_value`` (``mul_`` uses ``set_value_``).

    Args:
        src: The sparse matrix.
        other: Dense tensor with at least two dimensions, shaped as above.

    Returns:
        The ``SparseTensor`` produced by ``src.set_value`` with the products.

    Raises:
        ValueError: If ``other`` matches neither accepted shape (including
            when it has fewer than two dimensions).
    """
    rowptr, col, value = src.csr()

    # Checked first so that a 0-d/1-d ``other`` falls through to the
    # descriptive ValueError instead of failing inside ``other.size(1)``.
    ok_rank = other.dim() >= 2
    if ok_rank and other.size(0) == src.size(0) and other.size(1) == 1:
        # Row-wise: repeat other[i] once per non-zero of row i.
        other = gather_csr(other.squeeze(1), rowptr)
    elif ok_rank and other.size(0) == 1 and other.size(1) == src.size(1):
        # Column-wise: pick other[j] for the column of every non-zero.
        other = other.squeeze(0)[col]
    else:
        raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,'
                         f' ...) or (1, {src.size(1)}, ...), but got size '
                         f'{other.size()}.')

    if value is not None:
        other = other.to(value.dtype)
        val = value
        # ``other`` now has one leading entry per non-zero (it was gathered
        # via ``rowptr`` / indexed by ``col``); ``value`` is assumed to be
        # laid out as (nnz, ...) as well. Broadcasting aligns trailing dims,
        # so pad the lower-rank operand with trailing singleton dims to line
        # up the nnz dimension: e.g. (nnz, F) * (nnz,) -> (nnz, F) instead
        # of a misaligned (or failing) product.
        while val.dim() < other.dim():
            val = val.unsqueeze(-1)
        while other.dim() < val.dim():
            other = other.unsqueeze(-1)
        # Out of place, so the result may take the broadcast shape.
        value = other * val
    else:
        value = other
    return src.set_value(value, layout='coo')


@torch.jit.script
def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
    """In-place counterpart of ``mul``: scale the non-zeros of ``src``.

    ``other`` must have shape ``(src.size(0), 1, ...)`` (row-wise: each
    non-zero in row ``i`` uses ``other[i, 0]``) or ``(1, src.size(1), ...)``
    (column-wise: each non-zero in column ``j`` uses ``other[0, j]``).

    When ``src`` stores values, the value tensor returned by ``src.csr()``
    is multiplied in place and keeps its shape. Presumably this is the
    tensor stored inside ``src``; confirm in ``SparseTensor.csr``. When
    ``src`` stores no values, the expanded ``other`` is used instead. The
    result is stored with ``src.set_value_``.

    Args:
        src: The sparse matrix.
        other: Dense tensor with at least two dimensions, shaped as above.

    Returns:
        The ``SparseTensor`` returned by ``src.set_value_``.

    Raises:
        ValueError: If ``other`` matches neither accepted shape (including
            when it has fewer than two dimensions).
    """
    rowptr, col, value = src.csr()

    # Checked first so that a 0-d/1-d ``other`` falls through to the
    # descriptive ValueError instead of failing inside ``other.size(1)``.
    ok_rank = other.dim() >= 2
    if ok_rank and other.size(0) == src.size(0) and other.size(1) == 1:
        # Row-wise: repeat other[i] once per non-zero of row i.
        other = gather_csr(other.squeeze(1), rowptr)
    elif ok_rank and other.size(0) == 1 and other.size(1) == src.size(1):
        # Column-wise: pick other[j] for the column of every non-zero.
        other = other.squeeze(0)[col]
    else:
        raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,'
                         f' ...) or (1, {src.size(1)}, ...), but got size '
                         f'{other.size()}.')

    if value is not None:
        other = other.to(value.dtype)
        target = value
        # ``other`` has one leading entry per non-zero; ``value`` is assumed
        # to be (nnz, ...). Pad the lower-rank operand with trailing
        # singleton dims so broadcasting aligns the nnz dimension rather
        # than the last one. ``unsqueeze`` returns a view, so ``target``
        # still shares storage with ``value``.
        while target.dim() < other.dim():
            target = target.unsqueeze(-1)
        while other.dim() < target.dim():
            other = other.unsqueeze(-1)
        # Writes through the view into ``value``. In-place ops cannot grow
        # their output, so extra non-singleton dims in ``other`` raise here.
        target.mul_(other)
    else:
        value = other
    return src.set_value_(value, layout='coo')


@torch.jit.script
def mul_nnz(src: SparseTensor, other: torch.Tensor,
            layout: Optional[str] = None) -> SparseTensor:
    """Multiply the stored values of ``src`` element-wise by ``other``.

    ``other`` is cast to the dtype of the stored values and multiplied with
    them under ordinary broadcasting rules. If ``src`` stores no values,
    ``other`` itself becomes the value tensor.

    Args:
        src: The sparse tensor whose values are scaled.
        other: Dense tensor broadcastable against the stored values.
        layout: Passed through unchanged to ``src.set_value``.

    Returns:
        The ``SparseTensor`` produced by ``src.set_value`` (out-of-place
        setter; see ``mul_nnz_`` for the in-place variant).
    """
    current = src.storage.value()
    if current is not None:
        scaled = current.mul(other.to(current.dtype))
        return src.set_value(scaled, layout=layout)
    return src.set_value(other, layout=layout)


@torch.jit.script
def mul_nnz_(src: SparseTensor, other: torch.Tensor,
             layout: Optional[str] = None) -> SparseTensor:
    """In-place counterpart of ``mul_nnz``.

    The tensor returned by ``src.storage.value()`` is multiplied in place by
    ``other`` (cast to its dtype). If ``src`` stores no values, ``other``
    itself is installed as the value tensor.

    Args:
        src: The sparse tensor whose values are scaled.
        other: Dense tensor broadcastable against the stored values.
        layout: Passed through unchanged to ``src.set_value_``.

    Returns:
        The ``SparseTensor`` returned by ``src.set_value_``.
    """
    current = src.storage.value()
    if current is not None:
        # ``mul_`` overwrites ``current`` and returns that same tensor.
        current.mul_(other.to(current.dtype))
        return src.set_value_(current, layout=layout)
    return src.set_value_(other, layout=layout)


def _mul_method(self, other):
    """``SparseTensor.mul``: delegate to the scripted ``mul``."""
    return mul(self, other)


def _mul_inplace_method(self, other):
    """``SparseTensor.mul_``: delegate to the scripted ``mul_``."""
    return mul_(self, other)


def _mul_nnz_method(self, other, layout=None):
    """``SparseTensor.mul_nnz``: delegate to the scripted ``mul_nnz``."""
    return mul_nnz(self, other, layout)


def _mul_nnz_inplace_method(self, other, layout=None):
    """``SparseTensor.mul_nnz_``: delegate to the scripted ``mul_nnz_``."""
    return mul_nnz_(self, other, layout)


# Expose the scripted free functions as ``SparseTensor`` methods.
SparseTensor.mul = _mul_method
SparseTensor.mul_ = _mul_inplace_method
SparseTensor.mul_nnz = _mul_nnz_method
SparseTensor.mul_nnz_ = _mul_nnz_inplace_method

# Operator overloads: ``a * b`` and ``b * a`` both dispatch to ``a.mul(b)``;
# ``a *= b`` dispatches to ``a.mul_(b)``.
SparseTensor.__mul__ = SparseTensor.mul
SparseTensor.__rmul__ = SparseTensor.mul
SparseTensor.__imul__ = SparseTensor.mul_