Commit b199bcb0 authored by rusty1s's avatar rusty1s
Browse files

fixed gpu tests

parent 9a1f7817
cpu_tensors = [ tensors = [
'ByteTensor', 'CharTensor', 'ShortTensor', 'IntTensor', 'LongTensor', 'ByteTensor', 'CharTensor', 'ShortTensor', 'IntTensor', 'LongTensor',
'FloatTensor', 'DoubleTensor' 'FloatTensor', 'DoubleTensor'
] ]
gpu_tensors = ['cuda.{}'.format(t) for t in cpu_tensors + ['HalfTensor']]
...@@ -5,7 +5,7 @@ import torch ...@@ -5,7 +5,7 @@ import torch
import numpy as np import numpy as np
from torch_cluster import graclus_cluster from torch_cluster import graclus_cluster
from .tensor import cpu_tensors, gpu_tensors from .tensor import tensors
tests = [{ tests = [{
'row': [0, 0, 1, 1, 1, 2, 2, 2, 3, 3], 'row': [0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
...@@ -19,8 +19,8 @@ tests = [{ ...@@ -19,8 +19,8 @@ tests = [{
def assert_correct_graclus(row, col, cluster): def assert_correct_graclus(row, col, cluster):
row, col, cluster = row.numpy(), col.numpy(), cluster.numpy() row, col = row.cpu().numpy(), col.cpu().numpy()
n_nodes = cluster.shape[0] cluster, n_nodes = cluster.cpu().numpy(), cluster.size(0)
# Every node was assigned a cluster. # Every node was assigned a cluster.
assert cluster.min() >= 0 assert cluster.min() >= 0
...@@ -40,7 +40,7 @@ def assert_correct_graclus(row, col, cluster): ...@@ -40,7 +40,7 @@ def assert_correct_graclus(row, col, cluster):
assert x.sum() == y.sum() assert x.sum() == y.sum()
@pytest.mark.parametrize('tensor,i', product(cpu_tensors, range(len(tests)))) @pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests))))
def test_graclus_cluster_cpu(tensor, i): def test_graclus_cluster_cpu(tensor, i):
data = tests[i] data = tests[i]
...@@ -55,7 +55,7 @@ def test_graclus_cluster_cpu(tensor, i): ...@@ -55,7 +55,7 @@ def test_graclus_cluster_cpu(tensor, i):
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA') @pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
@pytest.mark.parametrize('tensor,i', product(gpu_tensors, range(len(tests)))) @pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests))))
def test_graclus_cluster_gpu(tensor, i): def test_graclus_cluster_gpu(tensor, i):
data = tests[i] data = tests[i]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment