test_tensor.py 1.57 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from itertools import product

import pytest
import torch
from torch_sparse import SparseTensor

from .utils import grad_dtypes, devices


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_getitem(dtype, device):
    mat = torch.randn(50, 40, dtype=dtype, device=device)
    mat = SparseTensor.from_dense(mat)

    idx1 = torch.randint(0, 50, (10, ), dtype=torch.long, device=device)
    idx2 = torch.randint(0, 40, (10, ), dtype=torch.long, device=device)

    assert mat[:10, :10].sizes() == [10, 10]
    assert mat[..., :10].sizes() == [50, 10]
    assert mat[idx1, idx2].sizes() == [10, 10]
    assert mat[idx1.tolist()].sizes() == [10, 40]
rusty1s's avatar
rusty1s committed
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39


@pytest.mark.parametrize('device', devices)
def test_to_symmetric(device):
    row = torch.tensor([0, 0, 0, 1, 1], device=device)
    col = torch.tensor([0, 1, 2, 0, 2], device=device)
    value = torch.arange(1, 6, device=device)
    mat = SparseTensor(row=row, col=col, value=value)
    assert not mat.is_symmetric()

    mat = mat.to_symmetric()

    assert mat.is_symmetric()
    assert mat.to_dense().tolist() == [
        [2, 6, 3],
        [6, 0, 5],
        [3, 5, 0],
    ]
rusty1s's avatar
rusty1s committed
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55


def test_equal():
    row = torch.tensor([0, 0, 0, 1, 1])
    col = torch.tensor([0, 1, 2, 0, 2])
    value = torch.arange(1, 6)
    matA = SparseTensor(row=row, col=col, value=value)
    matB = SparseTensor(row=row, col=col, value=value)
    col = torch.tensor([0, 1, 2, 0, 1])
    matC = SparseTensor(row=row, col=col, value=value)

    assert id(matA) != id(matB)
    assert matA == matB

    assert id(matA) != id(matC)
    assert matA != matC