"tests/data/vscode:/vscode.git/clone" did not exist on "0a8e67f753045556da5665b6167e4b4f810ab589"
Commit b199bcb0 authored by rusty1s's avatar rusty1s
Browse files

fixed gpu tests

parent 9a1f7817
cpu_tensors = [
tensors = [
'ByteTensor', 'CharTensor', 'ShortTensor', 'IntTensor', 'LongTensor',
'FloatTensor', 'DoubleTensor'
]
gpu_tensors = ['cuda.{}'.format(t) for t in cpu_tensors + ['HalfTensor']]
......@@ -5,7 +5,7 @@ import torch
import numpy as np
from torch_cluster import graclus_cluster
from .tensor import cpu_tensors, gpu_tensors
from .tensor import tensors
tests = [{
'row': [0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
......@@ -19,8 +19,8 @@ tests = [{
def assert_correct_graclus(row, col, cluster):
row, col, cluster = row.numpy(), col.numpy(), cluster.numpy()
n_nodes = cluster.shape[0]
row, col = row.cpu().numpy(), col.cpu().numpy()
cluster, n_nodes = cluster.cpu().numpy(), cluster.size(0)
# Every node was assigned a cluster.
assert cluster.min() >= 0
......@@ -40,7 +40,7 @@ def assert_correct_graclus(row, col, cluster):
assert x.sum() == y.sum()
@pytest.mark.parametrize('tensor,i', product(cpu_tensors, range(len(tests))))
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests))))
def test_graclus_cluster_cpu(tensor, i):
data = tests[i]
......@@ -55,7 +55,7 @@ def test_graclus_cluster_cpu(tensor, i):
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
@pytest.mark.parametrize('tensor,i', product(gpu_tensors, range(len(tests))))
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests))))
def test_graclus_cluster_gpu(tensor, i):
data = tests[i]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment