Commit 4e7de164 authored by Alexander Liao's avatar Alexander Liao
Browse files

remove pickle dependency

parent 0d663771
...@@ -3,7 +3,6 @@ from itertools import product ...@@ -3,7 +3,6 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch_cluster import knn, knn_graph from torch_cluster import knn, knn_graph
import pickle
from .utils import grad_dtypes, devices, tensor from .utils import grad_dtypes, devices, tensor
...@@ -61,29 +60,41 @@ def test_knn_graph(dtype, device): ...@@ -61,29 +60,41 @@ def test_knn_graph(dtype, device):
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_knn_graph_large(dtype, device): def test_knn_graph_large(dtype, device):
d = pickle.load(open("test/knn_test_large.pkl", "rb")) x = torch.tensor([[-1.0320, 0.2380, 0.2380],
x = d['x'].to(device) [-1.3050, -0.0930, 0.6420],
k = d['k'] [-0.3190, -0.0410, 1.2150],
truth = d['edges'] [1.1400, -0.5390, -0.3140],
[0.8410, 0.8290, 0.6090],
row, col = knn_graph(x, k=k, flow='source_to_target', [-1.4380, -0.2420, -0.3260],
batch=None, n_threads=24) [-2.2980, 0.7160, 0.9320],
[-1.3680, -0.4390, 0.1380],
[-0.6710, 0.6060, 1.1800],
[0.3950, -0.0790, 1.4920]],).to(device)
k = 3
truth = set({(4, 8), (2, 8), (9, 8), (8, 0), (0, 7), (2, 1), (9, 4),
(5, 1), (4, 9), (2, 9), (8, 1), (1, 5), (5, 0), (3, 2),
(8, 2), (7, 1), (6, 0), (3, 9), (0, 5), (7, 5), (4, 2),
(1, 0), (0, 1), (7, 0), (6, 8), (9, 2), (6, 1), (5, 7),
(1, 7), (3, 4)})
row, col = knn_graph(x, k=k, flow='target_to_source',
batch=None, n_threads=24, loop=False)
edges = set([(i, j) for (i, j) in zip(list(row.cpu().numpy()), edges = set([(i, j) for (i, j) in zip(list(row.cpu().numpy()),
list(col.cpu().numpy()))]) list(col.cpu().numpy()))])
assert(truth == edges) assert(truth == edges)
row, col = knn_graph(x, k=k, flow='source_to_target', row, col = knn_graph(x, k=k, flow='target_to_source',
batch=None, n_threads=12) batch=None, n_threads=12, loop=False)
edges = set([(i, j) for (i, j) in zip(list(row.cpu().numpy()), edges = set([(i, j) for (i, j) in zip(list(row.cpu().numpy()),
list(col.cpu().numpy()))]) list(col.cpu().numpy()))])
assert(truth == edges) assert(truth == edges)
row, col = knn_graph(x, k=k, flow='source_to_target', row, col = knn_graph(x, k=k, flow='target_to_source',
batch=None, n_threads=1) batch=None, n_threads=1, loop=False)
edges = set([(i, j) for (i, j) in zip(list(row.cpu().numpy()), edges = set([(i, j) for (i, j) in zip(list(row.cpu().numpy()),
list(col.cpu().numpy()))]) list(col.cpu().numpy()))])
......
...@@ -4,7 +4,14 @@ import pytest ...@@ -4,7 +4,14 @@ import pytest
import torch import torch
from torch_cluster import radius, radius_graph from torch_cluster import radius, radius_graph
from .utils import grad_dtypes, devices, tensor from .utils import grad_dtypes, devices, tensor
import pickle import scipy.spatial
@torch.jit.script
def sample(col: torch.Tensor, count: int) -> torch.Tensor:
if col.size(0) > count:
col = col[torch.randperm(col.size(0), dtype=torch.long)][:count]
return col
def coalesce(index): def coalesce(index):
...@@ -594,10 +601,22 @@ def test_radius_graph_ndim(dtype, device): ...@@ -594,10 +601,22 @@ def test_radius_graph_ndim(dtype, device):
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_radius_graph_large(dtype, device): def test_radius_graph_large(dtype, device):
d = pickle.load(open("test/radius_test_large.pkl", "rb")) x = torch.randn((8192*4, 6))
x = d['x'].to(device) r = 0.5
r = d['r']
truth = d['edges'] tree = scipy.spatial.cKDTree(x.detach().cpu().numpy())
col = tree.query_ball_point(x.detach().cpu().numpy(), r)
col = [torch.tensor(c, dtype=torch.long) for c in col]
col = [sample(c, 32) for c in col]
row = [torch.full_like(c, i) for i, c in enumerate(col)]
row, col = torch.cat(row, dim=0), torch.cat(col, dim=0)
mask = col < int(tree.n)
row_truth, col_truth = torch.stack([row[mask], col[mask]], dim=0)
mask = row_truth != col_truth
row_truth, col_truth = row_truth[mask], col_truth[mask]
truth = (set([(i, j) for (i, j) in zip(list(row_truth.cpu().numpy()),
list(col_truth.cpu().numpy()))]))
row, col = radius_graph(x, r=r, flow='source_to_target', row, col = radius_graph(x, r=r, flow='source_to_target',
batch=None, n_threads=24) batch=None, n_threads=24)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment