"docs/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "20273e550323b203dae44e8c585fb238294cb892"
Commit 2e57b128 authored by Duc's avatar Duc
Browse files

fixed overloading

parent f868d906
from typing import Optional from typing import Optional
from torch import Tensor
import torch import torch
@torch.jit.script @torch.jit._overload
def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None, def fps(src, batch=None, ratio=None, random_start=True):
ratio: torch.Tensor = torch.tensor(0.5), random_start: bool = True) -> torch.Tensor: # type: (Tensor, Optional[Tensor], Optional[int], bool) -> Tensor
pass
@torch.jit._overload
def fps(src, batch=None, ratio=None, random_start=True):
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool) -> Tensor
pass
def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True):
r""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature r""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature
Learning on Point Sets in a Metric Space" Learning on Point Sets in a Metric Space"
<https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the <https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the
...@@ -33,7 +43,10 @@ def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None, ...@@ -33,7 +43,10 @@ def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None,
index = fps(src, batch, ratio=0.5) index = fps(src, batch, ratio=0.5)
""" """
assert len(ratio.shape) < 2, 'Invalid ratio' if not isinstance(ratio, Tensor):
ratio = torch.tensor(ratio)
assert len(ratio.shape) < 2, f'ratio should be a scalar or a vector, received a tensor rank {len(ratio.shape)}'
ratio = ratio.to(src.device) ratio = ratio.to(src.device)
if batch is not None: if batch is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment