Commit 2e57b128 authored by Duc's avatar Duc
Browse files

fixed overloading

parent f868d906
from typing import Optional
from torch import Tensor
import torch
@torch.jit.script
def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None,
ratio: torch.Tensor = torch.tensor(0.5), random_start: bool = True) -> torch.Tensor:
@torch.jit._overload
def fps(src, batch=None, ratio=None, random_start=True):
# type: (Tensor, Optional[Tensor], Optional[int], bool) -> Tensor
pass
@torch.jit._overload
def fps(src, batch=None, ratio=None, random_start=True):
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool) -> Tensor
pass
def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True):
r""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature
Learning on Point Sets in a Metric Space"
<https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the
......@@ -33,7 +43,10 @@ def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None,
index = fps(src, batch, ratio=0.5)
"""
assert len(ratio.shape) < 2, 'Invalid ratio'
if not isinstance(ratio, Tensor):
ratio = torch.tensor(ratio)
assert len(ratio.shape) < 2, f'ratio should be a scalar or a vector, received a tensor rank {len(ratio.shape)}'
ratio = ratio.to(src.device)
if batch is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment