Commit 882c8e08 authored by rusty1s's avatar rusty1s
Browse files

test jit script

parent 53959eee
......@@ -2,11 +2,17 @@ from itertools import product
import pytest
import torch
from torch import Tensor
from torch_cluster import fps
from .utils import grad_dtypes, devices, tensor
@torch.jit.script
def fps2(x: Tensor, ratio: Tensor) -> Tensor:
return fps(x, None, ratio, False)
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_fps(dtype, device):
x = tensor([
......@@ -47,6 +53,9 @@ def test_fps(dtype, device):
out = fps(x, ratio=torch.tensor([0.5], device=device), random_start=False)
assert out.sort()[0].tolist() == [0, 5, 6, 7]
out = fps2(x, torch.tensor([0.5], device=device))
assert out.sort()[0].tolist() == [0, 5, 6, 7]
@pytest.mark.parametrize('device', devices)
def test_random_fps(device):
......
from typing import Optional
from torch import Tensor
import torch
from torch import Tensor
@torch.jit._overload # noqa
def fps(src, batch, ratio, random_start):
# type: (Tensor, Optional[Tensor], Optional[int], bool) -> Tensor
def fps(src, batch=None, ratio=None, random_start=True):
# type: (Tensor, Optional[Tensor], Optional[float], bool) -> Tensor
pass
@torch.jit._overload # noqa
def fps(src, batch, ratio, random_start):
def fps(src, batch=None, ratio=None, random_start=True):
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool) -> Tensor
pass
def fps(src: torch.Tensor, batch=None, ratio=0.5, random_start=True): # noqa
def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa
r""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature
Learning on Point Sets in a Metric Space"
<https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the
......@@ -45,8 +46,14 @@ def fps(src: torch.Tensor, batch=None, ratio=0.5, random_start=True): # noqa
index = fps(src, batch, ratio=0.5)
"""
if not isinstance(ratio, Tensor):
ratio = torch.tensor(ratio, device=src.device)
r: Optional[Tensor] = None
if ratio is None:
r = torch.tensor(0.5, dtype=src.dtype, device=src.device)
elif isinstance(ratio, float):
r = torch.tensor(ratio, dtype=src.dtype, device=src.device)
else:
r = ratio
assert r is not None
if batch is not None:
assert src.size(0) == batch.numel()
......@@ -60,4 +67,4 @@ def fps(src: torch.Tensor, batch=None, ratio=0.5, random_start=True): # noqa
else:
ptr = torch.tensor([0, src.size(0)], device=src.device)
return torch.ops.torch_cluster.fps(src, ptr, ratio, random_start)
return torch.ops.torch_cluster.fps(src, ptr, r, random_start)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment