"src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py" did not exist on "8d81564b27956dbabeeac833139aab27e60e379d"
masked_select.py 2.32 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch

from torch_sparse.storage import get_layout


def masked_select(src, dim, mask):
    """Keep only the slices of ``src`` along ``dim`` selected by ``mask``.

    Args:
        src: Sparse tensor object exposing ``dim()``, ``coo()``,
            ``has_value()``, ``sparse_size()``, ``storage`` and
            ``from_storage()``.
        dim: Dimension to filter. Negative values count from the last
            dimension (``-1`` is ``src.dim() - 1``).
        mask: One-dimensional tensor used to index along ``dim``.
            NOTE(review): assumed to be a boolean tensor with one entry per
            slice along ``dim`` -- confirm against callers.

    Returns:
        The result of ``src.from_storage`` on the filtered storage.

    * ``dim == 0``: rows are filtered; surviving rows are renumbered
      ``0..k-1`` and the sparse size becomes ``[k, src.sparse_size(1)]``.
    * ``dim == 1``: columns are filtered the same way and the sparse size
      becomes ``[src.sparse_size(0), k]``.
    * otherwise: only the values are touched, via ``index_select`` along
      ``dim - 1``.
    """
    # Wrap a negative dimension into range. (The sign matters: subtracting a
    # negative ``dim`` would push it past the last dimension instead.)
    dim = src.dim() + dim if dim < 0 else dim

    assert mask.dim() == 1

    if dim == 0:
        (row, col), value = src.coo()
        rowcount = src.storage.rowcount

        # Per-entry mask: an entry survives iff its row survives.
        row_mask = mask[row]
        rowcount = rowcount[mask]
        # Rebuild the row indices from the surviving per-row counts so that
        # the kept rows are numbered consecutively from zero.
        idx = torch.arange(rowcount.size(0), device=rowcount.device)
        row = idx.repeat_interleave(rowcount)
        col = col[row_mask]
        index = torch.stack([row, col], dim=0)

        if src.has_value():
            value = value[row_mask]

        sparse_size = torch.Size([rowcount.size(0), src.sparse_size(1)])

        storage = src.storage.__class__(index, value, sparse_size,
                                        rowcount=rowcount, is_sorted=True)

    elif dim == 1:
        # Permute the entries with ``csr2csc`` so that the per-column counts
        # line up with the entry order.
        csr2csc = src.storage.csr2csc
        row = src.storage.row[csr2csc]
        col = src.storage.col[csr2csc]
        colcount = src.storage.colcount

        # Per-entry mask: an entry survives iff its column survives.
        col_mask = mask[col]
        colcount = colcount[mask]
        # Renumber the surviving columns consecutively from zero.
        tmp = torch.arange(colcount.size(0), device=row.device)
        col = tmp.repeat_interleave(colcount)
        row = row[col_mask]
        # Sort by (row, col) using a linearised key to get back to
        # row-major order; the permutation is handed to the new storage.
        csc2csr = (colcount.size(0) * row + col).argsort()
        index = torch.stack([row, col], dim=0)[:, csc2csr]

        value = src.storage.value
        if src.has_value():
            # Apply the same three reorderings/filters as for the indices.
            value = value[csr2csc][col_mask][csc2csr]

        sparse_size = torch.Size([src.sparse_size(0), colcount.size(0)])

        storage = src.storage.__class__(index, value, sparse_size,
                                        colcount=colcount, csc2csr=csc2csr,
                                        is_sorted=True)

    else:
        # Remaining dimensions only affect the values; the value axis used
        # is ``dim - 1``.
        idx = mask.nonzero().view(-1)
        storage = src.storage.apply_value(
            lambda x: x.index_select(dim - 1, idx))

    return src.from_storage(storage)


def masked_select_nnz(src, mask, layout=None):
    """Filter the stored entries of ``src`` with a per-entry ``mask``.

    Args:
        src: Sparse tensor object exposing ``coo()``, ``has_value()``,
            ``sparse_size()``, ``storage`` and ``from_storage()``.
        mask: One-dimensional tensor used to index the entries returned by
            ``src.coo()``.
        layout: Passed to ``get_layout``; when that resolves to ``'csc'``
            the mask is first permuted with ``src.storage.csc2csr``
            (presumably because the caller supplied it in CSC entry order --
            confirm against ``get_layout``).

    Returns:
        The result of ``src.from_storage`` on a storage holding only the
        selected entries; the sparse size is unchanged.
    """
    assert mask.dim() == 1

    # Bring the mask into the entry order used by ``src.coo()``.
    nnz_mask = (mask[src.storage.csc2csr]
                if get_layout(layout) == 'csc' else mask)

    index, value = src.coo()

    kept_index = index[:, nnz_mask]
    kept_value = value[nnz_mask] if src.has_value() else value

    # Only indices, values and size are passed on; no other cached
    # information from the old storage is handed to the new one.
    storage_cls = src.storage.__class__
    filtered = storage_cls(kept_index, kept_value, src.sparse_size(),
                           is_sorted=True)

    return src.from_storage(filtered)