# test_diag.py (1.03 KB) — author: rusty1s
from itertools import product

import pytest
import torch
from torch_sparse.tensor import SparseTensor

from .utils import dtypes, devices, tensor


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_remove_diag(dtype, device):
    """Check ``SparseTensor.remove_diag`` on a small 3x3 sparsity pattern.

    The input has non-zeros at (0, 0), (0, 1), (1, 2) and (2, 2). Two
    removals are exercised, each on a freshly built matrix whose cache has
    been filled first:

    * the default call (main diagonal), and
    * ``k=1`` (first super-diagonal).

    For each, the surviving indices/values, the number of cached keys left
    on the result, and the row/column counts are compared against
    hand-computed expectations.
    """
    index = tensor([[0, 0, 1, 2], [0, 1, 2, 2]], torch.long, device)
    value = tensor([1, 2, 3, 4], dtype, device)

    # Default diagonal: entries (0, 0) and (2, 2) are removed.
    mat = SparseTensor(index, value)
    mat.fill_cache_()

    mat = mat.remove_diag()
    assert mat.storage.index.tolist() == [[0, 1], [1, 2]]
    assert mat.storage.value.tolist() == [2, 3]
    # The cache was filled before removal; the result is expected to carry
    # exactly two cached entries, and the row/column counts must describe
    # the remaining pattern rather than the original one.
    assert len(mat.cached_keys()) == 2
    assert mat.storage.rowcount.tolist() == [1, 1, 0]
    assert mat.storage.colcount.tolist() == [0, 1, 1]

    # First super-diagonal (k=1): entries (0, 1) and (1, 2) are removed.
    mat = SparseTensor(index, value)
    mat.fill_cache_()

    mat = mat.remove_diag(k=1)
    assert mat.storage.index.tolist() == [[0, 2], [0, 2]]
    assert mat.storage.value.tolist() == [1, 4]
    assert len(mat.cached_keys()) == 2
    assert mat.storage.rowcount.tolist() == [1, 0, 1]
    assert mat.storage.colcount.tolist() == [1, 0, 1]