"vscode:/vscode.git/clone" did not exist on "abc6c77853ba75dd8509f187456f3513abb302e4"
Commit 882c8e08 authored by rusty1s's avatar rusty1s
Browse files

test jit script

parent 53959eee
...@@ -2,11 +2,17 @@ from itertools import product ...@@ -2,11 +2,17 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch import Tensor
from torch_cluster import fps from torch_cluster import fps
from .utils import grad_dtypes, devices, tensor from .utils import grad_dtypes, devices, tensor
@torch.jit.script
def fps2(x: Tensor, ratio: Tensor) -> Tensor:
return fps(x, None, ratio, False)
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_fps(dtype, device): def test_fps(dtype, device):
x = tensor([ x = tensor([
...@@ -47,6 +53,9 @@ def test_fps(dtype, device): ...@@ -47,6 +53,9 @@ def test_fps(dtype, device):
out = fps(x, ratio=torch.tensor([0.5], device=device), random_start=False) out = fps(x, ratio=torch.tensor([0.5], device=device), random_start=False)
assert out.sort()[0].tolist() == [0, 5, 6, 7] assert out.sort()[0].tolist() == [0, 5, 6, 7]
out = fps2(x, torch.tensor([0.5], device=device))
assert out.sort()[0].tolist() == [0, 5, 6, 7]
@pytest.mark.parametrize('device', devices) @pytest.mark.parametrize('device', devices)
def test_random_fps(device): def test_random_fps(device):
......
from typing import Optional from typing import Optional
from torch import Tensor
import torch import torch
from torch import Tensor
@torch.jit._overload # noqa @torch.jit._overload # noqa
def fps(src, batch, ratio, random_start): def fps(src, batch=None, ratio=None, random_start=True):
# type: (Tensor, Optional[Tensor], Optional[int], bool) -> Tensor # type: (Tensor, Optional[Tensor], Optional[float], bool) -> Tensor
pass pass
@torch.jit._overload # noqa @torch.jit._overload # noqa
def fps(src, batch, ratio, random_start): def fps(src, batch=None, ratio=None, random_start=True):
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool) -> Tensor # type: (Tensor, Optional[Tensor], Optional[Tensor], bool) -> Tensor
pass pass
def fps(src: torch.Tensor, batch=None, ratio=0.5, random_start=True): # noqa def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa
r""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature r""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature
Learning on Point Sets in a Metric Space" Learning on Point Sets in a Metric Space"
<https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the <https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the
...@@ -45,8 +46,14 @@ def fps(src: torch.Tensor, batch=None, ratio=0.5, random_start=True): # noqa ...@@ -45,8 +46,14 @@ def fps(src: torch.Tensor, batch=None, ratio=0.5, random_start=True): # noqa
index = fps(src, batch, ratio=0.5) index = fps(src, batch, ratio=0.5)
""" """
if not isinstance(ratio, Tensor): r: Optional[Tensor] = None
ratio = torch.tensor(ratio, device=src.device) if ratio is None:
r = torch.tensor(0.5, dtype=src.dtype, device=src.device)
elif isinstance(ratio, float):
r = torch.tensor(ratio, dtype=src.dtype, device=src.device)
else:
r = ratio
assert r is not None
if batch is not None: if batch is not None:
assert src.size(0) == batch.numel() assert src.size(0) == batch.numel()
...@@ -60,4 +67,4 @@ def fps(src: torch.Tensor, batch=None, ratio=0.5, random_start=True): # noqa ...@@ -60,4 +67,4 @@ def fps(src: torch.Tensor, batch=None, ratio=0.5, random_start=True): # noqa
else: else:
ptr = torch.tensor([0, src.size(0)], device=src.device) ptr = torch.tensor([0, src.size(0)], device=src.device)
return torch.ops.torch_cluster.fps(src, ptr, ratio, random_start) return torch.ops.torch_cluster.fps(src, ptr, r, random_start)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment