view.py 1.6 KB
Newer Older
Mario Geiger's avatar
view  
Mario Geiger committed
1
2
3
4
5
6
7
8
9
10
import torch

from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor


def _view(src: SparseTensor, n: int, layout: str = 'csr') -> SparseTensor:
    row, col, value = src.coo()
    sparse_sizes = src.storage.sparse_sizes()

Mario Geiger's avatar
fix  
Mario Geiger committed
11
    if sparse_sizes[0] * sparse_sizes[1] % n != 0:
Mario Geiger's avatar
view  
Mario Geiger committed
12
        raise RuntimeError(
Mario Geiger's avatar
fix  
Mario Geiger committed
13
14
            f"shape '[-1, {n}]' is invalid for input of size "
            f"{sparse_sizes[0] * sparse_sizes[1]}")
Mario Geiger's avatar
view  
Mario Geiger committed
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59

    assert layout == 'csr' or layout == 'csc'

    if layout == 'csr':
        idx = sparse_sizes[1] * row + col
        row = idx // n
        col = idx % n
        sparse_sizes = (sparse_sizes[0] * sparse_sizes[1] // n, n)
    if layout == 'csc':
        idx = sparse_sizes[0] * col + row
        row = idx % n
        col = idx // n
        sparse_sizes = (n, sparse_sizes[0] * sparse_sizes[1] // n)

    storage = SparseStorage(
        row=row,
        rowptr=src.storage._rowptr,
        col=col,
        value=value,
        sparse_sizes=sparse_sizes,
        rowcount=src.storage._rowcount,
        colptr=src.storage._colptr,
        colcount=src.storage._colcount,
        csr2csc=src.storage._csr2csc,
        csc2csr=src.storage._csc2csr,
        is_sorted=True,
    )

    return src.from_storage(storage)


SparseTensor.view = lambda self, m, n: _view(self, n, layout='csr')

###############################################################################


def view(index, value, m, n, new_n):
    assert m * n % new_n == 0

    row, col = index
    idx = n * row + col
    row = idx // new_n
    col = idx % new_n

    return torch.stack([row, col], dim=0), value