Unverified Commit 29cd22bf authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Move `torch.jit.script` check to test (#194)

* update

* update

* update
parent 89b74f0a
cmake_minimum_required(VERSION 3.0)
project(torchcluster)
set(CMAKE_CXX_STANDARD 14)
set(TORCHCLUSTER_VERSION 1.6.2)
set(TORCHCLUSTER_VERSION 1.6.3)
option(WITH_CUDA "Enable CUDA support" OFF)
option(WITH_PYTHON "Link to Python when building" ON)
......
package:
name: pytorch-cluster
version: 1.6.2
version: 1.6.3
source:
path: ../..
......
......@@ -11,7 +11,7 @@ from torch.__config__ import parallel_info
from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CppExtension,
CUDAExtension)
__version__ = '1.6.2'
__version__ = '1.6.3'
URL = 'https://github.com/rusty1s/pytorch_cluster'
WITH_CUDA = False
......
......@@ -50,3 +50,7 @@ def test_graclus_cluster(test, dtype, device):
cluster = graclus_cluster(row, col, weight)
assert_correct(row, col, cluster)
jit = torch.jit.script(graclus_cluster)
cluster = jit(row, col, weight)
assert_correct(row, col, cluster)
......@@ -38,3 +38,6 @@ def test_grid_cluster(test, dtype, device):
cluster = grid_cluster(pos, size, start, end)
assert cluster.tolist() == test['cluster']
jit = torch.jit.script(grid_cluster)
assert torch.equal(jit(pos, size, start, end), cluster)
......@@ -34,6 +34,10 @@ def test_knn(dtype, device):
edge_index = knn(x, y, 2)
assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)])
jit = torch.jit.script(knn)
edge_index = jit(x, y, 2)
assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)])
edge_index = knn(x, y, 2, batch_x, batch_y)
assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)])
......@@ -65,6 +69,11 @@ def test_knn_graph(dtype, device):
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
(3, 2), (0, 3), (2, 3)])
jit = torch.jit.script(knn_graph)
edge_index = jit(x, k=2, flow='source_to_target')
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
(3, 2), (0, 3), (2, 3)])
@pytest.mark.parametrize('dtype,device', product([torch.float], devices))
def test_knn_graph_large(dtype, device):
......
......@@ -35,6 +35,11 @@ def test_radius(dtype, device):
assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 1),
(1, 2), (1, 5), (1, 6)])
jit = torch.jit.script(radius)
edge_index = jit(x, y, 2, max_num_neighbors=4)
assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 1),
(1, 2), (1, 5), (1, 6)])
edge_index = radius(x, y, 2, batch_x, batch_y, max_num_neighbors=4)
assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 5),
(1, 6)])
......@@ -64,12 +69,20 @@ def test_radius_graph(dtype, device):
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
(3, 2), (0, 3), (2, 3)])
jit = torch.jit.script(radius_graph)
edge_index = jit(x, r=2.5, flow='source_to_target')
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
(3, 2), (0, 3), (2, 3)])
@pytest.mark.parametrize('dtype,device', product([torch.float], devices))
def test_radius_graph_large(dtype, device):
x = torch.randn(1000, 3, dtype=dtype, device=device)
edge_index = radius_graph(x, r=0.5, flow='target_to_source', loop=True,
edge_index = radius_graph(x,
r=0.5,
flow='target_to_source',
loop=True,
max_num_neighbors=2000)
tree = scipy.spatial.cKDTree(x.cpu().numpy())
......
......@@ -31,6 +31,9 @@ def test_rw_small(device):
out = random_walk(row, col, start, walk_length, num_nodes=3)
assert out.tolist() == [[0, 1, 0, 1, 0], [1, 0, 1, 0, 1], [2, 2, 2, 2, 2]]
jit = torch.jit.script(random_walk)
assert torch.equal(jit(row, col, start, walk_length, num_nodes=3), out)
@pytest.mark.parametrize('device', devices)
def test_rw_large_with_edge_indices(device):
......
......@@ -3,7 +3,7 @@ import os.path as osp
import torch
__version__ = '1.6.2'
__version__ = '1.6.3'
for library in [
'_version', '_grid', '_graclus', '_fps', '_rw', '_sampler', '_nearest',
......
......@@ -3,10 +3,12 @@ from typing import Optional
import torch
@torch.jit.script
def graclus_cluster(row: torch.Tensor, col: torch.Tensor,
def graclus_cluster(
row: torch.Tensor,
col: torch.Tensor,
weight: Optional[torch.Tensor] = None,
num_nodes: Optional[int] = None) -> torch.Tensor:
num_nodes: Optional[int] = None,
) -> torch.Tensor:
"""A greedy clustering algorithm of picking an unmarked vertex and matching
it with one its unmarked neighbors (that maximizes its edge weight).
......
......@@ -3,10 +3,12 @@ from typing import Optional
import torch
@torch.jit.script
def grid_cluster(pos: torch.Tensor, size: torch.Tensor,
def grid_cluster(
pos: torch.Tensor,
size: torch.Tensor,
start: Optional[torch.Tensor] = None,
end: Optional[torch.Tensor] = None) -> torch.Tensor:
end: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""A clustering algorithm, which overlays a regular grid of user-defined
size over a point cloud and clusters all points within a voxel.
......
......@@ -3,7 +3,6 @@ from typing import Optional
import torch
@torch.jit.script
def knn(
x: torch.Tensor,
y: torch.Tensor,
......@@ -83,7 +82,6 @@ def knn(
num_workers)
@torch.jit.script
def knn_graph(
x: torch.Tensor,
k: int,
......
......@@ -3,7 +3,6 @@ from typing import Optional
import torch
@torch.jit.script
def radius(
x: torch.Tensor,
y: torch.Tensor,
......@@ -84,7 +83,6 @@ def radius(
max_num_neighbors, num_workers)
@torch.jit.script
def radius_graph(
x: torch.Tensor,
r: float,
......
......@@ -4,7 +4,6 @@ import torch
from torch import Tensor
@torch.jit.script
def random_walk(
row: Tensor,
col: Tensor,
......@@ -55,8 +54,7 @@ def random_walk(
torch.cumsum(deg, 0, out=rowptr[1:])
node_seq, edge_seq = torch.ops.torch_cluster.random_walk(
rowptr, col, start, walk_length, p, q,
)
rowptr, col, start, walk_length, p, q)
if return_edge_indices:
return node_seq, edge_seq
......
import torch
@torch.jit.script
def neighbor_sampler(start: torch.Tensor, rowptr: torch.Tensor, size: float):
assert not start.is_cuda
......
import torch
WITH_PTR_LIST = hasattr(torch.ops.torch_cluster, 'fps_ptr_list')
try:
WITH_PTR_LIST = hasattr(torch.ops.torch_cluster, 'fps_ptr_list')
except Exception:
WITH_PTR_LIST = False
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment