test_cat.py 508 Bytes
Newer Older
rusty1s's avatar
cat  
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from itertools import product

import pytest
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.cat import cat

from .utils import dtypes, devices, tensor


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_cat(dtype, device):
    """Smoke-test ``cat`` on two index-only ``SparseTensor``s with ``dim=(0, 1)``.

    The test passes as long as the ``cat`` call does not raise; its return
    value is not inspected.

    NOTE(review): ``dtype`` is parametrized but never used here -- both index
    tensors are built as ``torch.long`` and no values are attached. Either
    attach ``dtype``-typed values or drop ``dtype`` from the parametrization.
    TODO: assert on the concatenated result (e.g. its sparse size / dense
    form). With ``dim=(0, 1)`` this is presumably a block-diagonal
    concatenation -- confirm against ``torch_sparse.cat`` before asserting.
    """
    first_index = tensor([[0, 0, 1], [0, 1, 2]], torch.long, device)
    first = SparseTensor(first_index)

    second_index = tensor([[0, 0, 1, 2], [0, 1, 1, 0]], torch.long, device)
    second = SparseTensor(second_index)

    # Only checked for not raising; see TODO above.
    pieces = [first, second]
    cat(pieces, dim=(0, 1))