Unverified Commit 29cd22bf authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Move `torch.jit.script` check to test (#194)

* update

* update

* update
parent 89b74f0a
cmake_minimum_required(VERSION 3.0) cmake_minimum_required(VERSION 3.0)
project(torchcluster) project(torchcluster)
set(CMAKE_CXX_STANDARD 14) set(CMAKE_CXX_STANDARD 14)
set(TORCHCLUSTER_VERSION 1.6.2) set(TORCHCLUSTER_VERSION 1.6.3)
option(WITH_CUDA "Enable CUDA support" OFF) option(WITH_CUDA "Enable CUDA support" OFF)
option(WITH_PYTHON "Link to Python when building" ON) option(WITH_PYTHON "Link to Python when building" ON)
......
package: package:
name: pytorch-cluster name: pytorch-cluster
version: 1.6.2 version: 1.6.3
source: source:
path: ../.. path: ../..
......
...@@ -11,7 +11,7 @@ from torch.__config__ import parallel_info ...@@ -11,7 +11,7 @@ from torch.__config__ import parallel_info
from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CppExtension, from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CppExtension,
CUDAExtension) CUDAExtension)
__version__ = '1.6.2' __version__ = '1.6.3'
URL = 'https://github.com/rusty1s/pytorch_cluster' URL = 'https://github.com/rusty1s/pytorch_cluster'
WITH_CUDA = False WITH_CUDA = False
......
...@@ -50,3 +50,7 @@ def test_graclus_cluster(test, dtype, device): ...@@ -50,3 +50,7 @@ def test_graclus_cluster(test, dtype, device):
cluster = graclus_cluster(row, col, weight) cluster = graclus_cluster(row, col, weight)
assert_correct(row, col, cluster) assert_correct(row, col, cluster)
jit = torch.jit.script(graclus_cluster)
cluster = jit(row, col, weight)
assert_correct(row, col, cluster)
...@@ -38,3 +38,6 @@ def test_grid_cluster(test, dtype, device): ...@@ -38,3 +38,6 @@ def test_grid_cluster(test, dtype, device):
cluster = grid_cluster(pos, size, start, end) cluster = grid_cluster(pos, size, start, end)
assert cluster.tolist() == test['cluster'] assert cluster.tolist() == test['cluster']
jit = torch.jit.script(grid_cluster)
assert torch.equal(jit(pos, size, start, end), cluster)
...@@ -34,6 +34,10 @@ def test_knn(dtype, device): ...@@ -34,6 +34,10 @@ def test_knn(dtype, device):
edge_index = knn(x, y, 2) edge_index = knn(x, y, 2)
assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)]) assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)])
jit = torch.jit.script(knn)
edge_index = jit(x, y, 2)
assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)])
edge_index = knn(x, y, 2, batch_x, batch_y) edge_index = knn(x, y, 2, batch_x, batch_y)
assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)]) assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)])
...@@ -65,6 +69,11 @@ def test_knn_graph(dtype, device): ...@@ -65,6 +69,11 @@ def test_knn_graph(dtype, device):
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2), assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
(3, 2), (0, 3), (2, 3)]) (3, 2), (0, 3), (2, 3)])
jit = torch.jit.script(knn_graph)
edge_index = jit(x, k=2, flow='source_to_target')
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
(3, 2), (0, 3), (2, 3)])
@pytest.mark.parametrize('dtype,device', product([torch.float], devices)) @pytest.mark.parametrize('dtype,device', product([torch.float], devices))
def test_knn_graph_large(dtype, device): def test_knn_graph_large(dtype, device):
......
...@@ -35,6 +35,11 @@ def test_radius(dtype, device): ...@@ -35,6 +35,11 @@ def test_radius(dtype, device):
assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 1), assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 1),
(1, 2), (1, 5), (1, 6)]) (1, 2), (1, 5), (1, 6)])
jit = torch.jit.script(radius)
edge_index = jit(x, y, 2, max_num_neighbors=4)
assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 1),
(1, 2), (1, 5), (1, 6)])
edge_index = radius(x, y, 2, batch_x, batch_y, max_num_neighbors=4) edge_index = radius(x, y, 2, batch_x, batch_y, max_num_neighbors=4)
assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 5), assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 5),
(1, 6)]) (1, 6)])
...@@ -64,12 +69,20 @@ def test_radius_graph(dtype, device): ...@@ -64,12 +69,20 @@ def test_radius_graph(dtype, device):
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2), assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
(3, 2), (0, 3), (2, 3)]) (3, 2), (0, 3), (2, 3)])
jit = torch.jit.script(radius_graph)
edge_index = jit(x, r=2.5, flow='source_to_target')
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
(3, 2), (0, 3), (2, 3)])
@pytest.mark.parametrize('dtype,device', product([torch.float], devices)) @pytest.mark.parametrize('dtype,device', product([torch.float], devices))
def test_radius_graph_large(dtype, device): def test_radius_graph_large(dtype, device):
x = torch.randn(1000, 3, dtype=dtype, device=device) x = torch.randn(1000, 3, dtype=dtype, device=device)
edge_index = radius_graph(x, r=0.5, flow='target_to_source', loop=True, edge_index = radius_graph(x,
r=0.5,
flow='target_to_source',
loop=True,
max_num_neighbors=2000) max_num_neighbors=2000)
tree = scipy.spatial.cKDTree(x.cpu().numpy()) tree = scipy.spatial.cKDTree(x.cpu().numpy())
......
...@@ -31,6 +31,9 @@ def test_rw_small(device): ...@@ -31,6 +31,9 @@ def test_rw_small(device):
out = random_walk(row, col, start, walk_length, num_nodes=3) out = random_walk(row, col, start, walk_length, num_nodes=3)
assert out.tolist() == [[0, 1, 0, 1, 0], [1, 0, 1, 0, 1], [2, 2, 2, 2, 2]] assert out.tolist() == [[0, 1, 0, 1, 0], [1, 0, 1, 0, 1], [2, 2, 2, 2, 2]]
jit = torch.jit.script(random_walk)
assert torch.equal(jit(row, col, start, walk_length, num_nodes=3), out)
@pytest.mark.parametrize('device', devices) @pytest.mark.parametrize('device', devices)
def test_rw_large_with_edge_indices(device): def test_rw_large_with_edge_indices(device):
......
...@@ -3,7 +3,7 @@ import os.path as osp ...@@ -3,7 +3,7 @@ import os.path as osp
import torch import torch
__version__ = '1.6.2' __version__ = '1.6.3'
for library in [ for library in [
'_version', '_grid', '_graclus', '_fps', '_rw', '_sampler', '_nearest', '_version', '_grid', '_graclus', '_fps', '_rw', '_sampler', '_nearest',
......
...@@ -3,10 +3,12 @@ from typing import Optional ...@@ -3,10 +3,12 @@ from typing import Optional
import torch import torch
@torch.jit.script def graclus_cluster(
def graclus_cluster(row: torch.Tensor, col: torch.Tensor, row: torch.Tensor,
col: torch.Tensor,
weight: Optional[torch.Tensor] = None, weight: Optional[torch.Tensor] = None,
num_nodes: Optional[int] = None) -> torch.Tensor: num_nodes: Optional[int] = None,
) -> torch.Tensor:
"""A greedy clustering algorithm of picking an unmarked vertex and matching """A greedy clustering algorithm of picking an unmarked vertex and matching
it with one its unmarked neighbors (that maximizes its edge weight). it with one its unmarked neighbors (that maximizes its edge weight).
......
...@@ -3,10 +3,12 @@ from typing import Optional ...@@ -3,10 +3,12 @@ from typing import Optional
import torch import torch
@torch.jit.script def grid_cluster(
def grid_cluster(pos: torch.Tensor, size: torch.Tensor, pos: torch.Tensor,
size: torch.Tensor,
start: Optional[torch.Tensor] = None, start: Optional[torch.Tensor] = None,
end: Optional[torch.Tensor] = None) -> torch.Tensor: end: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""A clustering algorithm, which overlays a regular grid of user-defined """A clustering algorithm, which overlays a regular grid of user-defined
size over a point cloud and clusters all points within a voxel. size over a point cloud and clusters all points within a voxel.
......
...@@ -3,7 +3,6 @@ from typing import Optional ...@@ -3,7 +3,6 @@ from typing import Optional
import torch import torch
@torch.jit.script
def knn( def knn(
x: torch.Tensor, x: torch.Tensor,
y: torch.Tensor, y: torch.Tensor,
...@@ -83,7 +82,6 @@ def knn( ...@@ -83,7 +82,6 @@ def knn(
num_workers) num_workers)
@torch.jit.script
def knn_graph( def knn_graph(
x: torch.Tensor, x: torch.Tensor,
k: int, k: int,
......
...@@ -3,7 +3,6 @@ from typing import Optional ...@@ -3,7 +3,6 @@ from typing import Optional
import torch import torch
@torch.jit.script
def radius( def radius(
x: torch.Tensor, x: torch.Tensor,
y: torch.Tensor, y: torch.Tensor,
...@@ -84,7 +83,6 @@ def radius( ...@@ -84,7 +83,6 @@ def radius(
max_num_neighbors, num_workers) max_num_neighbors, num_workers)
@torch.jit.script
def radius_graph( def radius_graph(
x: torch.Tensor, x: torch.Tensor,
r: float, r: float,
......
...@@ -4,7 +4,6 @@ import torch ...@@ -4,7 +4,6 @@ import torch
from torch import Tensor from torch import Tensor
@torch.jit.script
def random_walk( def random_walk(
row: Tensor, row: Tensor,
col: Tensor, col: Tensor,
...@@ -55,8 +54,7 @@ def random_walk( ...@@ -55,8 +54,7 @@ def random_walk(
torch.cumsum(deg, 0, out=rowptr[1:]) torch.cumsum(deg, 0, out=rowptr[1:])
node_seq, edge_seq = torch.ops.torch_cluster.random_walk( node_seq, edge_seq = torch.ops.torch_cluster.random_walk(
rowptr, col, start, walk_length, p, q, rowptr, col, start, walk_length, p, q)
)
if return_edge_indices: if return_edge_indices:
return node_seq, edge_seq return node_seq, edge_seq
......
import torch import torch
@torch.jit.script
def neighbor_sampler(start: torch.Tensor, rowptr: torch.Tensor, size: float): def neighbor_sampler(start: torch.Tensor, rowptr: torch.Tensor, size: float):
assert not start.is_cuda assert not start.is_cuda
......
import torch import torch
WITH_PTR_LIST = hasattr(torch.ops.torch_cluster, 'fps_ptr_list') try:
WITH_PTR_LIST = hasattr(torch.ops.torch_cluster, 'fps_ptr_list')
except Exception:
WITH_PTR_LIST = False
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment