"docker/release/Dockerfile" did not exist on "c344c110042df482608011a87e5b52fa183b370c"
test_diag.py 2.4 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
from itertools import product

import pytest
import torch
from torch_sparse.tensor import SparseTensor

from .utils import dtypes, devices, tensor


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_remove_diag(dtype, device):
    """Check ``remove_diag`` on the main diagonal and on diagonal ``k=1``.

    For every case a fresh matrix is built from the same COO data, its cache
    is populated, a diagonal is removed, and the surviving entries plus the
    cached per-row / per-column counts are compared against expectations.
    The matrix is 3x3 (sizes inferred from the largest row/column index 2).
    """
    row, col = tensor([[0, 0, 1, 2], [0, 1, 2, 2]], torch.long, device)
    value = tensor([1, 2, 3, 4], dtype, device)

    # Each case: (kwargs for remove_diag,
    #             expected row, col, value, rowcount, colcount).
    cases = [
        # Default call: drops (0, 0) and (2, 2).
        ({}, [0, 1], [1, 2], [2, 3], [1, 1, 0], [0, 1, 1]),
        # k=1: drops the first upper diagonal, i.e. (0, 1) and (1, 2).
        ({'k': 1}, [0, 2], [0, 2], [1, 4], [1, 0, 1], [1, 0, 1]),
    ]

    for kwargs, exp_row, exp_col, exp_val, exp_rowcnt, exp_colcnt in cases:
        out = SparseTensor(row=row, col=col, value=value)
        out.fill_cache_()
        out = out.remove_diag(**kwargs)

        storage = out.storage
        assert storage.row().tolist() == exp_row
        assert storage.col().tolist() == exp_col
        assert storage.value().tolist() == exp_val
        # Only the row/column counts should survive in the cache; checked
        # before rowcount()/colcount() are read below, as in the original.
        assert storage.num_cached_keys() == 2
        assert storage.rowcount().tolist() == exp_rowcnt
        assert storage.colcount().tolist() == exp_colcnt


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_set_diag(dtype, device):
    """Check ``set_diag`` on off-diagonals of a non-square matrix.

    The indices describe a 10x2 matrix (largest row index 9, largest column
    index 1; assumes sizes are inferred from the indices -- confirm against
    ``SparseTensor``). Diagonal ``k=-1`` then covers (1, 0) and (2, 1), and
    ``k=1`` covers only (0, 1), which already stores a value that must be
    overwritten rather than duplicated.
    """
    row, col = tensor([[0, 0, 9, 9], [0, 1, 0, 1]], torch.long, device)
    value = tensor([1, 2, 3, 4], dtype, device)
    mat = SparseTensor(row=row, col=col, value=value)

    mat = mat.set_diag(tensor([-8, -8], dtype, device), k=-1)
    mat = mat.set_diag(tensor([-8], dtype, device), k=1)

    # Previously the test asserted nothing, so any result (or duplicated
    # entries) passed. Compare sorted COO triples so the check is independent
    # of the internal storage order but still detects duplicates.
    entries = sorted(zip(mat.storage.row().tolist(),
                         mat.storage.col().tolist(),
                         mat.storage.value().tolist()))
    assert entries == [
        (0, 0, 1),
        (0, 1, -8),  # existing value 2 replaced
        (1, 0, -8),
        (2, 1, -8),
        (9, 0, 3),
        (9, 1, 4),
    ]


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_fill_diag(dtype, device):
    """Check ``fill_diag`` fills off-diagonals of a non-square matrix.

    The indices describe a 10x2 matrix (largest row index 9, largest column
    index 1; assumes sizes are inferred from the indices -- confirm against
    ``SparseTensor``). Filling ``k=-1`` must touch exactly (1, 0) and (2, 1);
    filling ``k=1`` must touch only (0, 1), overwriting its stored value.
    """
    row, col = tensor([[0, 0, 9, 9], [0, 1, 0, 1]], torch.long, device)
    value = tensor([1, 2, 3, 4], dtype, device)
    mat = SparseTensor(row=row, col=col, value=value)

    mat = mat.fill_diag(-8, k=-1)
    mat = mat.fill_diag(-8, k=1)

    # Previously the test asserted nothing. Compare sorted COO triples so
    # the check ignores storage order but catches wrong diagonal lengths
    # (e.g. an extra out-of-range entry for negative k) and duplicates.
    entries = sorted(zip(mat.storage.row().tolist(),
                         mat.storage.col().tolist(),
                         mat.storage.value().tolist()))
    assert entries == [
        (0, 0, 1),
        (0, 1, -8),  # existing value 2 replaced
        (1, 0, -8),
        (2, 1, -8),
        (9, 0, 3),
        (9, 1, 4),
    ]


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_get_diag(dtype, device):
    """Check ``get_diag`` with multi-dimensional values and with no values.

    Diagonal position (1, 1) is not stored, so its slot comes back as zeros.
    For a matrix built without values, stored diagonal positions read as 1.
    """
    row, col = tensor([[0, 0, 1, 2], [0, 1, 2, 2]], torch.long, device)

    # Each stored entry carries a length-2 feature vector.
    feats = tensor([[1, 1], [2, 2], [3, 3], [4, 4]], dtype, device)
    weighted = SparseTensor(row=row, col=col, value=feats)
    expected_weighted = [[1, 1], [0, 0], [4, 4]]
    assert weighted.get_diag().tolist() == expected_weighted

    # Same sparsity pattern, but no explicit values.
    pattern_only = SparseTensor(row=row, col=col)
    assert pattern_only.get_diag().tolist() == [1, 0, 1]