test_cat.py 1.75 KB
Newer Older
rusty1s's avatar
cat  
rusty1s committed
1
2
3
4
5
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.cat import cat

rusty1s's avatar
rusty1s committed
6
from .utils import devices, tensor
rusty1s's avatar
cat  
rusty1s committed
7
8


rusty1s's avatar
rusty1s committed
9
10
@pytest.mark.parametrize('device', devices)
def test_cat(device):
    """Check ``cat`` of two ``SparseTensor`` matrices along each dim mode.

    Covers ``dim=0`` (row stacking), ``dim=1`` (column stacking),
    ``dim=(0, 1)`` (block-diagonal) and ``dim=-1`` (value concatenation).
    For every case it compares the dense result and asserts which storage
    entries (row, rowptr, rowcount, colptr, colcount, cached keys) the
    output carries.
    """
    # The expected dense outputs below imply mat1 is the 2x3 matrix
    # [[1, 1, 0], [0, 0, 1]] and mat2 is the 3x2 matrix
    # [[1, 1], [0, 1], [1, 0]].
    row, col = tensor([[0, 0, 1], [0, 1, 2]], torch.long, device)
    mat1 = SparseTensor(row=row, col=col)
    # Populate caches on the inputs, presumably so the assertions can check
    # which cached entries `cat` carries over to its result.
    mat1.fill_cache_()

    row, col = tensor([[0, 0, 1, 2], [0, 1, 1, 0]], torch.long, device)
    mat2 = SparseTensor(row=row, col=col)
    mat2.fill_cache_()

    # dim=0: rows of mat2 are appended below mat1; the narrower mat2 is
    # zero-padded to 3 columns in the dense view.
    out = cat([mat1, mat2], dim=0)
    assert out.to_dense().tolist() == [[1, 1, 0], [0, 0, 1], [1, 1, 0],
                                       [0, 1, 0], [1, 0, 0]]
    assert out.storage.has_row()
    assert out.storage.has_rowptr()
    assert len(out.storage.cached_keys()) == 1
    assert out.storage.has_rowcount()

    # dim=1: columns of mat2 are appended to the right of mat1; the result
    # keeps `row` but not `rowptr`, and carries column-side caches.
    out = cat([mat1, mat2], dim=1)
    assert out.to_dense().tolist() == [[1, 1, 0, 1, 1], [0, 0, 1, 0, 1],
                                       [0, 0, 0, 1, 0]]
    assert out.storage.has_row()
    assert not out.storage.has_rowptr()
    assert len(out.storage.cached_keys()) == 2
    assert out.storage.has_colcount()
    assert out.storage.has_colptr()

    # dim=(0, 1): block-diagonal concatenation (mat1 top-left, mat2
    # bottom-right).
    out = cat([mat1, mat2], dim=(0, 1))
    assert out.to_dense().tolist() == [[1, 1, 0, 0, 0], [0, 0, 1, 0, 0],
                                       [0, 0, 0, 1, 1], [0, 0, 0, 0, 1],
                                       [0, 0, 0, 1, 0]]
    assert out.storage.has_row()
    assert out.storage.has_rowptr()
    assert len(out.storage.cached_keys()) == 5

    # dim=-1: concatenates the per-entry values. Two (nnz, 4) value tensors
    # yield an (nnz, 8) value tensor. Note that set_value_ mutates mat1 in
    # place, so it is done only after the structural cases above.
    mat1.set_value_(torch.randn((mat1.nnz(), 4), device=device))
    out = cat([mat1, mat1], dim=-1)
    assert out.storage.value.size() == (mat1.nnz(), 8)
    assert out.storage.has_row()
    assert out.storage.has_rowptr()
    assert len(out.storage.cached_keys()) == 5