test_grid.py 1.03 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
from itertools import product

import pytest
Matthias Fey's avatar
Matthias Fey committed
4
import torch
rusty1s's avatar
rusty1s committed
5
from torch_cluster import grid_cluster
Matthias Fey's avatar
Matthias Fey committed
6
from torch_cluster.testing import devices, dtypes, tensor
rusty1s's avatar
rusty1s committed
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28

# Parametrised cases for ``grid_cluster``. Every entry supplies the point
# coordinates (``pos``), a per-dimension ``size`` (presumably the voxel edge
# length), optional grid bounds (``start`` / ``end``) and the expected
# per-point ``cluster`` ids. When ``start``/``end`` are omitted, the library's
# defaults apply. The expected ids look consistent with those defaults being
# the per-dimension min/max of ``pos`` -- NOTE(review): confirm upstream.
tests = [
    # 1-D, default bounds: both points map to cluster 0.
    dict(pos=[2, 6], size=[5], cluster=[0, 0]),
    # 1-D, same points with an explicit start of 0: they split into 0 and 1.
    dict(pos=[2, 6], size=[5], start=[0], cluster=[0, 1]),
    # 2-D, default bounds.
    dict(
        pos=[[0, 0], [11, 9], [2, 8], [2, 2], [8, 3]],
        size=[5, 5],
        cluster=[0, 5, 3, 0, 1],
    ),
    # 2-D, same points with an explicit ``end``; only the bound differs from
    # the previous case, yet the expected ids change.
    dict(
        pos=[[0, 0], [11, 9], [2, 8], [2, 2], [8, 3]],
        size=[5, 5],
        end=[19, 19],
        cluster=[0, 6, 4, 0, 1],
    ),
]


rusty1s's avatar
rusty1s committed
29
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_grid_cluster(test, dtype, device):
    """Check ``grid_cluster`` against the expected ids of one ``tests`` case.

    Parametrised over every case in ``tests`` crossed with every ``dtype`` and
    ``device`` provided by ``torch_cluster.testing``.
    """
    # bfloat16 is excluded on CUDA (presumably unsupported by the CUDA
    # kernel -- confirm). Compare the device *type* rather than equality with
    # ``cuda:0`` so the exclusion also covers ``cuda``/``cuda:1``/..., and
    # report it as a skip instead of a silently passing test.
    if dtype == torch.bfloat16 and torch.device(device).type == 'cuda':
        pytest.skip('bfloat16 is not tested on CUDA')

    pos = tensor(test['pos'], dtype, device)
    size = tensor(test['size'], dtype, device)
    # ``start``/``end`` are optional per case; ``test.get`` yields None when a
    # case omits them (``tensor`` presumably passes None through -- see
    # torch_cluster.testing).
    start = tensor(test.get('start'), dtype, device)
    end = tensor(test.get('end'), dtype, device)

    cluster = grid_cluster(pos, size, start, end)
    assert cluster.tolist() == test['cluster']