test_matrix_op.py 1.28 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import backend as F
import pytest
import torch

from .utils import (
    rand_coo,
    rand_csc,
    rand_csr,
    rand_diag,
    sparse_matrix_to_dense,
)


@pytest.mark.parametrize(
    "create_func", [rand_diag, rand_csr, rand_csc, rand_coo]
)
@pytest.mark.parametrize("dim", [0, 1])
@pytest.mark.parametrize("index", [None, (1, 3), (4, 0, 2)])
def test_compact(create_func, dim, index):
    """Check ``A.compact(dim, index)`` against a hand-built dense reference.

    The expected id order is: the entries of ``index`` (duplicates dropped,
    first-occurrence order kept), followed in ascending order by every other
    row (``dim == 0``) or column (``dim == 1``) of ``A`` whose dense form has
    at least one non-zero value. Selecting those rows/columns from ``A`` with
    ``index_select`` must reproduce the compacted matrix, and the ids returned
    by ``compact`` must equal the expected ids exactly.
    """
    ctx = F.ctx()
    shape = (5, 5)

    # Explicitly requested ids lead the expected order; dict.fromkeys removes
    # duplicates while preserving first-occurrence order.
    expected_ids = [] if index is None else list(dict.fromkeys(index))
    if index is not None:
        index = torch.tensor(index).to(ctx)

    A = create_func(shape, 8, ctx)

    A_compact, ret_id = A.compact(dim, index)
    A_compact_dense = sparse_matrix_to_dense(A_compact)

    A_dense = sparse_matrix_to_dense(A)

    # Append every remaining non-empty row/column in ascending order.
    # select(dim, i) is A_dense[i, :] for dim == 0 and A_dense[:, i] for dim == 1.
    requested = set(expected_ids)
    for i in range(shape[dim]):
        if i not in requested and A_dense.select(dim, i).count_nonzero() > 0:
            expected_ids.append(i)

    # Always build an integer tensor: an empty Python list would otherwise be
    # passed straight to index_select / the comparison below, and
    # torch.tensor([]) without a dtype would be float32.
    expected_ids = torch.tensor(expected_ids, dtype=torch.int64).to(ctx)
    A_dense_select = sparse_matrix_to_dense(A.index_select(dim, expected_ids))

    assert A_compact_dense.shape == A_dense_select.shape
    assert torch.allclose(A_compact_dense, A_dense_select)
    # torch.equal also compares sizes; torch.allclose broadcasts and could
    # hide a wrong number of returned ids.
    assert torch.equal(expected_ids, ret_id)