Commit 1edd387e authored by rusty1s's avatar rusty1s
Browse files

pytorch 0.4.0

parent b87eab0b
tensors = [
'ByteTensor', 'CharTensor', 'ShortTensor', 'IntTensor', 'LongTensor',
'FloatTensor', 'DoubleTensor'
]
...@@ -2,10 +2,9 @@ from itertools import product ...@@ -2,10 +2,9 @@ from itertools import product
import pytest import pytest
import torch import torch
import numpy as np
from torch_cluster import graclus_cluster from torch_cluster import graclus_cluster
from .tensor import tensors from .utils import dtypes, devices, tensor
tests = [{ tests = [{
'row': [0, 0, 1, 1, 1, 2, 2, 2, 3, 3], 'row': [0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
...@@ -18,51 +17,34 @@ tests = [{ ...@@ -18,51 +17,34 @@ tests = [{
def assert_correct_graclus(row, col, cluster): def assert_correct_graclus(row, col, cluster):
row, col = row.cpu().numpy(), col.cpu().numpy() row, col, cluster = row.to('cpu'), col.to('cpu'), cluster.to('cpu')
cluster, n_nodes = cluster.cpu().numpy(), cluster.size(0) n = cluster.size(0)
# Every node was assigned a cluster. # Every node was assigned a cluster.
assert cluster.min() >= 0 assert cluster.min() >= 0
# There are no more than two nodes in each cluster. # There are no more than two nodes in each cluster.
_, count = np.unique(cluster, return_counts=True) _, index = torch.unique(cluster, return_inverse=True)
count = torch.zeros_like(cluster)
count.scatter_add_(0, index, torch.ones_like(cluster))
assert (count > 2).max() == 0 assert (count > 2).max() == 0
# Cluster value is minimal. # Cluster value is minimal.
assert (cluster <= np.arange(n_nodes, dtype=row.dtype)).sum() == n_nodes assert (cluster <= torch.arange(n, dtype=cluster.dtype)).sum() == n
# Corresponding clusters must be adjacent. # Corresponding clusters must be adjacent.
for n in range(cluster.shape[0]): for i in range(n):
x = cluster[col[row == n]] == cluster[n] # Neighbors with same cluster x = cluster[col[row == i]] == cluster[i] # Neighbors with same cluster
y = cluster == cluster[n] # Nodes with same cluster y = cluster == cluster[i] # Nodes with same cluster.
y[n] = 0 # Do not look at cluster of node `n`. y[i] = 0 # Do not look at cluster of `i`.
assert x.sum() == y.sum() assert x.sum() == y.sum()
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests)))) @pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_graclus_cluster_cpu(tensor, i): def test_graclus_cluster_cpu(test, dtype, device):
data = tests[i] row = tensor(test['row'], torch.long, device)
col = tensor(test['col'], torch.long, device)
row = torch.LongTensor(data['row']) weight = tensor(test.get('weight'), dtype, device)
col = torch.LongTensor(data['col'])
weight = data.get('weight')
weight = weight if weight is None else getattr(torch, tensor)(weight)
cluster = graclus_cluster(row, col, weight)
assert_correct_graclus(row, col, cluster)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests))))
def test_graclus_cluster_gpu(tensor, i): # pragma: no cover
data = tests[i]
row = torch.cuda.LongTensor(data['row'])
col = torch.cuda.LongTensor(data['col'])
weight = data.get('weight')
weight = weight if weight is None else getattr(torch.cuda, tensor)(weight)
cluster = graclus_cluster(row, col, weight) cluster = graclus_cluster(row, col, weight)
assert_correct_graclus(row, col, cluster) assert_correct_graclus(row, col, cluster)
from itertools import product from itertools import product
import pytest import pytest
import torch
from torch_cluster import grid_cluster from torch_cluster import grid_cluster
from .tensor import tensors from .utils import dtypes, devices, tensor
tests = [{ tests = [{
'pos': [2, 6], 'pos': [2, 6],
...@@ -27,36 +26,12 @@ tests = [{ ...@@ -27,36 +26,12 @@ tests = [{
}] }]
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests)))) @pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_grid_cluster_cpu(tensor, i): def test_grid_cluster_cpu(test, dtype, device):
data = tests[i] pos = tensor(test['pos'], dtype, device)
size = tensor(test['size'], dtype, device)
pos = getattr(torch, tensor)(data['pos']) start = tensor(test.get('start'), dtype, device)
size = getattr(torch, tensor)(data['size']) end = tensor(test.get('end'), dtype, device)
start = data.get('start')
start = start if start is None else getattr(torch, tensor)(start)
end = data.get('end')
end = end if end is None else getattr(torch, tensor)(end)
cluster = grid_cluster(pos, size, start, end)
assert cluster.tolist() == data['cluster']
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests))))
def test_grid_cluster_gpu(tensor, i): # pragma: no cover
data = tests[i]
pos = getattr(torch.cuda, tensor)(data['pos'])
size = getattr(torch.cuda, tensor)(data['size'])
start = data.get('start')
start = start if start is None else getattr(torch.cuda, tensor)(start)
end = data.get('end')
end = end if end is None else getattr(torch.cuda, tensor)(end)
cluster = grid_cluster(pos, size, start, end) cluster = grid_cluster(pos, size, start, end)
assert cluster.tolist() == data['cluster'] assert cluster.tolist() == test['cluster']
import torch
from torch.testing import get_all_dtypes
dtypes = get_all_dtypes()
dtypes.remove(torch.half)
devices = [torch.device('cpu')]
if torch.cuda.is_available():
devices += [torch.device('cuda:{}'.format(torch.cuda.current_device()))]
def tensor(x, dtype, device):
return None if x is None else torch.tensor(x, dtype=dtype, device=device)
...@@ -29,7 +29,7 @@ def graclus_cluster(row, col, weight=None, num_nodes=None): ...@@ -29,7 +29,7 @@ def graclus_cluster(row, col, weight=None, num_nodes=None):
row, col = randperm_sort_row(row, col, num_nodes) row, col = randperm_sort_row(row, col, num_nodes)
row, col = remove_self_loops(row, col) row, col = remove_self_loops(row, col)
cluster = row.new(num_nodes) cluster = row.new_empty((num_nodes, ))
graclus(cluster, row, col, weight) graclus(cluster, row, col, weight)
return cluster return cluster
...@@ -3,7 +3,7 @@ from .._ext import ffi ...@@ -3,7 +3,7 @@ from .._ext import ffi
def get_func(name, is_cuda, tensor=None): def get_func(name, is_cuda, tensor=None):
prefix = 'THCC' if is_cuda else 'TH' prefix = 'THCC' if is_cuda else 'TH'
prefix += 'Tensor' if tensor is None else type(tensor).__name__ prefix += 'Tensor' if tensor is None else tensor.type().split('.')[-1]
return getattr(ffi, '{}_{}'.format(prefix, name)) return getattr(ffi, '{}_{}'.format(prefix, name))
......
...@@ -3,7 +3,7 @@ import torch ...@@ -3,7 +3,7 @@ import torch
def randperm(row, col): def randperm(row, col):
# Randomly reorder row and column indices. # Randomly reorder row and column indices.
edge_rid = torch.randperm(row.size(0)).type_as(row) edge_rid = torch.randperm(row.size(0))
return row[edge_rid], col[edge_rid] return row[edge_rid], col[edge_rid]
...@@ -16,7 +16,7 @@ def sort_row(row, col): ...@@ -16,7 +16,7 @@ def sort_row(row, col):
def randperm_sort_row(row, col, num_nodes): def randperm_sort_row(row, col, num_nodes):
# Randomly change row indices to new values. # Randomly change row indices to new values.
node_rid = torch.randperm(num_nodes).type_as(row) node_rid = torch.randperm(num_nodes)
row = node_rid[row] row = node_rid[row]
# Sort row and column indices row-wise. # Sort row and column indices row-wise.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment