"""Tests for ``SparseTensor.remove_diag`` and ``SparseTensor.set_diag``."""

from itertools import product

import pytest
import torch
from torch_sparse.tensor import SparseTensor

from .utils import dtypes, devices, tensor


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_remove_diag(dtype, device):
    """Check ``remove_diag`` with its default argument and with ``k=1``.

    The fixture has four stored entries: (0, 0)=1, (0, 1)=2, (1, 2)=3 and
    (2, 2)=4. Every scenario starts from a freshly built matrix whose cache
    has been filled, calls ``remove_diag`` and then verifies, in order: the
    remaining row indices, column indices and values, the number of cached
    keys, and the ``rowcount`` / ``colcount`` storage fields.
    """
    index = tensor([[0, 0, 1, 2], [0, 1, 2, 2]], torch.long, device)
    row, col = index
    value = tensor([1, 2, 3, 4], dtype, device)

    # One tuple per scenario:
    # (remove_diag kwargs, rows, cols, values, rowcount, colcount)
    scenarios = [
        ({}, [0, 1], [1, 2], [2, 3], [1, 1, 0], [0, 1, 1]),
        ({'k': 1}, [0, 2], [0, 2], [1, 4], [1, 0, 1], [1, 0, 1]),
    ]

    for kwargs, rows, cols, values, rowcount, colcount in scenarios:
        # Rebuild from the raw tensors so scenarios do not share state.
        mat = SparseTensor(row=row, col=col, value=value)
        mat.fill_cache_()

        mat = mat.remove_diag(**kwargs)

        assert mat.storage.row.tolist() == rows
        assert mat.storage.col.tolist() == cols
        assert mat.storage.value.tolist() == values
        # Checked before rowcount/colcount are read below (presumably
        # reading them could populate the cache -- TODO confirm).
        assert len(mat.cached_keys()) == 2
        assert mat.storage.rowcount.tolist() == rowcount
        assert mat.storage.colcount.tolist() == colcount


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_set_diag(dtype, device):
    """Smoke-test ``set_diag``: the call only has to complete without raising.

    The matrix has four stored entries at rows 0 and 9, columns 0 and 1.
    Nothing is asserted about the result.
    """
    index = tensor([[0, 0, 9, 9], [0, 1, 0, 1]], torch.long, device)
    row, col = index
    value = tensor([1, 2, 3, 4], dtype, device)

    # -8 is passed positionally as set_diag's first argument; presumably a
    # diagonal offset, given it was named ``k`` -- TODO confirm.
    SparseTensor(row=row, col=col, value=value).set_diag(-8)