Commit 53959eee authored by rusty1s's avatar rusty1s
Browse files

clean up

parent 5f1939fd
......@@ -13,12 +13,12 @@ torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
CHECK_CPU(src);
CHECK_CPU(ptr);
CHECK_CPU(ratio);
CHECK_INPUT(ptr.dim() == 1);
// AT_ASSERTM(at::all(at::__and__(at::gt(ratio, 0), at::lt(ratio, 1))), "Invalid input");
src = src.view({src.size(0), -1}).contiguous();
ptr = ptr.contiguous();
auto batch_size = ptr.size(0) - 1;
auto batch_size = ptr.numel() - 1;
auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
auto out_ptr = deg.toType(torch::kFloat) * ratio;
......
......@@ -3,7 +3,7 @@
#include <ATen/cuda/CUDAContext.h>
#include "utils.cuh"
#include <stdio.h>
#define THREADS 256
template <typename scalar_t>
......@@ -64,19 +64,18 @@ __global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
}
}
torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
bool random_start) {
torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
torch::Tensor ratio, bool random_start) {
CHECK_CUDA(src);
CHECK_CUDA(ptr);
CHECK_CUDA(ratio);
CHECK_INPUT(ptr.dim() == 1);
// AT_ASSERTM(at::all(at::__and__(at::gt(ratio, 0), at::lt(ratio, 1))), "Invalid input");
cudaSetDevice(src.get_device());
src = src.view({src.size(0), -1}).contiguous();
ptr = ptr.contiguous();
ratio = ratio.contiguous();
auto batch_size = ptr.size(0) - 1;
auto batch_size = ptr.numel() - 1;
auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
auto out_ptr = deg.toType(torch::kFloat) * ratio;
......
......@@ -2,5 +2,5 @@
#include <torch/extension.h>
torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
bool random_start);
torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
torch::Tensor ratio, bool random_start);
......@@ -98,7 +98,7 @@ setup(
ext_modules=get_extensions() if not BUILD_DOCS else [],
cmdclass={
'build_ext':
BuildExtension.with_options(no_python_abi_suffix=True)
BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=False)
},
packages=find_packages(),
)
......@@ -21,16 +21,30 @@ def test_fps(dtype, device):
], dtype, device)
batch = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
out = fps(x, batch, random_start=False)
assert out.tolist() == [0, 2, 4, 6]
out = fps(x, batch, ratio=0.5, random_start=False)
assert out.tolist() == [0, 2, 4, 6]
out = fps(x, batch, ratio=torch.tensor(0.5), random_start=False)
out = fps(x, batch, ratio=torch.tensor(0.5, device=device),
random_start=False)
assert out.tolist() == [0, 2, 4, 6]
out = fps(x, batch, ratio=torch.tensor([0.5, 0.5]), random_start=False)
out = fps(x, batch, ratio=torch.tensor([0.5, 0.5], device=device),
random_start=False)
assert out.tolist() == [0, 2, 4, 6]
out = fps(x, ratio=torch.tensor(0.5), random_start=False)
out = fps(x, random_start=False)
assert out.sort()[0].tolist() == [0, 5, 6, 7]
out = fps(x, ratio=0.5, random_start=False)
assert out.sort()[0].tolist() == [0, 5, 6, 7]
out = fps(x, ratio=torch.tensor(0.5, device=device), random_start=False)
assert out.sort()[0].tolist() == [0, 5, 6, 7]
out = fps(x, ratio=torch.tensor([0.5], device=device), random_start=False)
assert out.sort()[0].tolist() == [0, 5, 6, 7]
......@@ -42,5 +56,5 @@ def test_random_fps(device):
batch_1 = torch.zeros(N, dtype=torch.long, device=device)
batch_2 = torch.ones(N, dtype=torch.long, device=device)
batch = torch.cat([batch_1, batch_2])
idx = fps(pos, batch, ratio=torch.tensor(0.5))
idx = fps(pos, batch, ratio=0.5)
assert idx.min() >= 0 and idx.max() < 2 * N
......@@ -3,19 +3,19 @@ from torch import Tensor
import torch
@torch.jit._overload
def fps(src, batch=None, ratio=None, random_start=True):
@torch.jit._overload # noqa
def fps(src, batch, ratio, random_start):
# type: (Tensor, Optional[Tensor], Optional[int], bool) -> Tensor
pass
@torch.jit._overload
def fps(src, batch=None, ratio=None, random_start=True):
@torch.jit._overload # noqa
def fps(src, batch, ratio, random_start):
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool) -> Tensor
pass
def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True):
def fps(src: torch.Tensor, batch=None, ratio=0.5, random_start=True): # noqa
r""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature
Learning on Point Sets in a Metric Space"
<https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the
......@@ -27,12 +27,14 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True):
batch (LongTensor, optional): Batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
node to a specific example. (default: :obj:`None`)
ratio (Tensor, optional): Sampling ratio. (default: :obj:`0.5`)
ratio (float or Tensor, optional): Sampling ratio.
(default: :obj:`0.5`)
random_start (bool, optional): If set to :obj:`False`, use the first
node in :math:`\mathbf{X}` as starting node. (default: obj:`True`)
:rtype: :class:`LongTensor`
.. code-block:: python
import torch
......@@ -44,10 +46,7 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True):
"""
if not isinstance(ratio, Tensor):
ratio = torch.tensor(ratio)
assert len(ratio.shape) < 2, f'ratio should be a scalar or a vector, received a tensor rank {len(ratio.shape)}'
ratio = ratio.to(src.device)
ratio = torch.tensor(ratio, device=src.device)
if batch is not None:
assert src.size(0) == batch.numel()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment