Commit c7f73ca2 authored by rusty1s's avatar rusty1s
Browse files

added grid tests

parent 4576030c
...@@ -10,7 +10,6 @@ from .tensor import tensors ...@@ -10,7 +10,6 @@ from .tensor import tensors
tests = [{ tests = [{
'row': [0, 0, 1, 1, 1, 2, 2, 2, 3, 3], 'row': [0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
'col': [1, 2, 0, 2, 3, 0, 1, 3, 1, 2], 'col': [1, 2, 0, 2, 3, 0, 1, 3, 1, 2],
'weight': None,
}, { }, {
'row': [0, 0, 1, 1, 1, 2, 2, 2, 3, 3], 'row': [0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
'col': [1, 2, 0, 2, 3, 0, 1, 3, 1, 2], 'col': [1, 2, 0, 2, 3, 0, 1, 3, 1, 2],
...@@ -47,7 +46,7 @@ def test_graclus_cluster_cpu(tensor, i): ...@@ -47,7 +46,7 @@ def test_graclus_cluster_cpu(tensor, i):
row = torch.LongTensor(data['row']) row = torch.LongTensor(data['row'])
col = torch.LongTensor(data['col']) col = torch.LongTensor(data['col'])
weight = data['weight'] weight = data.get('weight')
weight = weight if weight is None else getattr(torch, tensor)(weight) weight = weight if weight is None else getattr(torch, tensor)(weight)
cluster = graclus_cluster(row, col, weight) cluster = graclus_cluster(row, col, weight)
...@@ -62,7 +61,7 @@ def test_graclus_cluster_gpu(tensor, i): # pragma: no cover ...@@ -62,7 +61,7 @@ def test_graclus_cluster_gpu(tensor, i): # pragma: no cover
row = torch.cuda.LongTensor(data['row']) row = torch.cuda.LongTensor(data['row'])
col = torch.cuda.LongTensor(data['col']) col = torch.cuda.LongTensor(data['col'])
weight = data['weight'] weight = data.get('weight')
weight = weight if weight is None else getattr(torch.cuda, tensor)(weight) weight = weight if weight is None else getattr(torch.cuda, tensor)(weight)
cluster = graclus_cluster(row, col, weight) cluster = graclus_cluster(row, col, weight)
......
from itertools import product
import pytest
import torch
from torch_cluster import grid_cluster
from .tensor import tensors
tests = [{
'pos': [2, 6],
'size': [5],
'cluster': [0, 0],
}, {
'pos': [2, 6],
'size': [5],
'start': [0],
'cluster': [0, 1],
}, {
'pos': [[0, 0], [11, 9], [2, 8], [2, 2], [8, 3]],
'size': [5, 5],
'cluster': [0, 5, 3, 0, 1],
}, {
'pos': [[0, 0], [11, 9], [2, 8], [2, 2], [8, 3]],
'size': [5, 5],
'end': [19, 19],
'cluster': [0, 6, 4, 0, 1],
}]
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests))))
def test_grid_cluster_cpu(tensor, i):
data = tests[i]
pos = getattr(torch, tensor)(data['pos'])
size = getattr(torch, tensor)(data['size'])
start = data.get('start')
start = start if start is None else getattr(torch, tensor)(start)
end = data.get('end')
end = end if end is None else getattr(torch, tensor)(end)
cluster = grid_cluster(pos, size, start, end)
assert cluster.tolist() == data['cluster']
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests))))
def test_grid_cluster_gpu(tensor, i):
data = tests[i]
pos = getattr(torch.cuda, tensor)(data['pos'])
size = getattr(torch.cuda, tensor)(data['size'])
start = data.get('start')
start = start if start is None else getattr(torch.cuda, tensor)(start)
end = data.get('end')
end = end if end is None else getattr(torch.cuda, tensor)(end)
cluster = grid_cluster(pos, size, start, end)
assert cluster.tolist() == data['cluster']
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment