Commit 932b96e2 authored by rusty1s's avatar rusty1s
Browse files

generic test

parent 385e1bca
......@@ -5,5 +5,6 @@ void THCCCharGreedy(THCudaLongTensor *cluster, THCudaLongTensor *row, THCudaLo
void THCCShortGreedy(THCudaLongTensor *cluster, THCudaLongTensor *row, THCudaLongTensor *col, THCudaShortTensor *weight);
void THCCIntGreedy(THCudaLongTensor *cluster, THCudaLongTensor *row, THCudaLongTensor *col, THCudaIntTensor *weight);
void THCCLongGreedy(THCudaLongTensor *cluster, THCudaLongTensor *row, THCudaLongTensor *col, THCudaLongTensor *weight);
void THCCHalfGreedy(THCudaLongTensor *cluster, THCudaLongTensor *row, THCudaLongTensor *col, THCudaHalfTensor *weight);
void THCCFloatGreedy(THCudaLongTensor *cluster, THCudaLongTensor *row, THCudaLongTensor *col, THCudaTensor *weight);
void THCCDoubleGreedy(THCudaLongTensor *cluster, THCudaLongTensor *row, THCudaLongTensor *col, THCudaDoubleTensor *weight);
......@@ -3,5 +3,6 @@ void THCCCharGrid(THCudaLongTensor *cluster, THCudaCharTensor *pos, THCuda
void THCCShortGrid(THCudaLongTensor *cluster, THCudaShortTensor *pos, THCudaShortTensor *size, THCudaLongTensor *count);
void THCCIntGrid(THCudaLongTensor *cluster, THCudaIntTensor *pos, THCudaIntTensor *size, THCudaLongTensor *count);
void THCCLongGrid(THCudaLongTensor *cluster, THCudaLongTensor *pos, THCudaLongTensor *size, THCudaLongTensor *count);
void THCCHalfGrid(THCudaLongTensor *cluster, THCudaHalfTensor *pos, THCudaHalfTensor *size, THCudaLongTensor *count);
void THCCFloatGrid(THCudaLongTensor *cluster, THCudaTensor *pos, THCudaTensor *size, THCudaLongTensor *count);
void THCCDoubleGrid(THCudaLongTensor *cluster, THCudaDoubleTensor *pos, THCudaDoubleTensor *size, THCudaLongTensor *count);
cpu_tensors = [
'ByteTensor', 'CharTensor', 'ShortTensor', 'IntTensor', 'LongTensor',
'FloatTensor', 'DoubleTensor'
]
cuda_tensors = ['cuda.{}'.format(t) for t in cpu_tensors + ['HalfTensor']]
from itertools import product
import pytest
import torch
import numpy as np
from torch_cluster import graclus_cluster
from .tensor import cpu_tensors, cuda_tensors
tests = [{
'row': [0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
'col': [1, 2, 0, 2, 3, 0, 1, 3, 1, 2],
'weight': None,
}, {
'row': [0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
'col': [1, 2, 0, 2, 3, 0, 1, 3, 1, 2],
'weight': [1, 2, 1, 3, 2, 2, 3, 1, 2, 1],
}]
def assert_correct_graclus(row, col, cluster):
row, col, cluster = row.numpy(), col.numpy(), cluster.numpy()
......@@ -15,16 +30,21 @@ def assert_correct_graclus(row, col, cluster):
# Corresponding clusters must be adjacent.
for n in range(cluster.shape[0]):
assert (cluster[col[row == n]] == cluster[n]).max() == 1
x = cluster[col[row == n]] == cluster[n] # Neighbors with same cluster
y = cluster == cluster[n] # Nodes with same cluster
y[n] = 0 # Do not look at cluster of node `n`.
assert x.sum() == y.sum()
def test_graclus_cluster_cpu():
row = torch.LongTensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3])
col = torch.LongTensor([1, 2, 0, 2, 3, 0, 1, 3, 1, 2])
weight = torch.Tensor([1, 2, 1, 3, 2, 2, 3, 1, 2, 1])
@pytest.mark.parametrize('tensor,i', product(cpu_tensors, range(len(tests))))
def test_graclus_cluster_cpu(tensor, i):
data = tests[i]
cluster = graclus_cluster(row, col)
assert_correct_graclus(row, col, cluster)
row = torch.LongTensor(data['row'])
col = torch.LongTensor(data['col'])
weight = data['weight']
weight = weight if weight is None else getattr(torch, tensor)(weight)
cluster = graclus_cluster(row, col, weight)
assert_correct_graclus(row, col, cluster)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment